Conversation

realAsma (Contributor) commented Sep 15, 2025

What does this PR do?

Type of change: new tests; QATTrainer workflow fixes and simplification

Overview:

  1. ModelOpt entry points now accept all distributed wrapped models. Previously DDP/FSDP-wrapped models were not accepted at ModelOpt entry points; after this PR all supported wrappers are, which simplifies ModelOpt workflows (see the sketch after this list).

  2. Fixed the QATTrainer FSDP2 workflow disruption and unblocked QLoRA with FSDP2: previously, for FSDP2 the ModelOpt states were restored after the weights were loaded. This broke the ModelOpt workflow, which requires the ModelOpt states to be restored before the weights are loaded, and it made the FSDP2 flow incompatible with QLoRA. This PR fixes that ordering.

  3. This PR makes several improvements to QATTrainer and the training workflow:
    i. Simplified the QATTrainer workflow from the user side (removed the eval_only argument).
    ii. Cleaned up QATTrainer to work seamlessly with various backends such as DDP, FSDP1, FSDP2, and DeepSpeed.
    iii. Added unit tests for llm_qat with various backends.
    iv. Removed examples/llm_qat/convert_sharded_ckpt.py.
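
A minimal sketch of item 1 (not code from this PR): it assumes the usual `import modelopt.torch.quantization as mtq` alias used in the examples, an already initialized process group and correctly placed model for DDP, and uses NVFP4_DEFAULT_CFG only as an example config name.

import torch
import modelopt.torch.quantization as mtq  # assumed alias, as used in the examples

def quantize_ddp_model(model, calib_loader):
    # After this PR the wrapped model can be handed to the entry point directly;
    # ModelOpt unwraps DDP/FSDP/DeepSpeed internally.
    ddp_model = torch.nn.parallel.DistributedDataParallel(model)

    def forward_loop(unwrapped_model):
        # mtq.quantize passes the unwrapped model to the forward loop.
        for batch in calib_loader:
            unwrapped_model(**batch)

    return mtq.quantize(ddp_model, mtq.NVFP4_DEFAULT_CFG, forward_loop)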

Usage

See examples/llm_qat/main.py

Testing

See the updated unit tests

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: To Do
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: This will be done in a future PR
  • Did you update Changelog?: Not needed

Additional Information

Summary by CodeRabbit

  • New Features

    • Backend selection (fsdp1/fsdp2/ddp/deepspeed) with a DeepSpeed accelerate config, mixed-precision options, quantizer state save/restore, and a new perplexity metric; training now logs "Training completed." and evaluation uses evaluate().
  • Refactor

    • Unified launch script argument parsing and backend-driven startup; improved wrapper handling and FSDP2 compatibility; deprecation warnings for legacy flags.
  • Tests

    • Tests parameterized over backends and improved distributed port handling.
  • Chores

    • Removed automatic post-launch sharded-checkpoint conversion; minor README cleanup.

realAsma requested review from a team as code owners September 15, 2025 16:04
realAsma requested review from mxinO and ChenhanYu September 15, 2025 16:04

copy-pr-bot bot commented Sep 15, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/quantization/utils.py (1)

459-466: Symmetric fix for restore path.

Use the model‑aware name to match saved keys across wrapper changes.

Apply this diff:

-    for name, module in model.named_modules():
-        if isinstance(module, TensorQuantizer) and get_unwrapped_name(name) in quantizer_state_dict:
-            module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name)])
+    for name, module in model.named_modules():
+        if isinstance(module, TensorQuantizer):
+            key = get_unwrapped_name(name, model)
+            if key in quantizer_state_dict:
+                module.load_state_dict(quantizer_state_dict[key])
examples/llm_qat/launch.sh (1)

126-129: Division by zero when no GPUs detected.

DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) fails when GPU_COUNT=0 (CPU‑only CI/dev). Guard it.

Apply this diff:

-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+GPU_COUNT=$(python -c "import torch; n=torch.cuda.device_count(); print(n if n>0 else 1)")
+# Calculate save_steps (fallback to 1 on CPU-only)
+DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
🧹 Nitpick comments (10)
modelopt/torch/quantization/nn/modules/quant_module.py (1)

161-164: Make context manager re-entrant-safe by restoring prior state

If quantize_weight() is nested or called when _enable_weight_quantization was already True, hard‑resetting to False on exit can flip the state incorrectly. Preserve and restore the previous value.

     def quantize_weight(self):
         """Context in which `self.weight` is quantized."""
-        self._enable_weight_quantization = True
-        try:
-            yield
-        finally:
-            self._enable_weight_quantization = False
+        prev = getattr(self, "_enable_weight_quantization", False)
+        self._enable_weight_quantization = True
+        try:
+            yield
+        finally:
+            self._enable_weight_quantization = prev
examples/llm_qat/utils.py (1)

172-175: Perplexity value can be overwritten by existing key; compute-first merge reverses precedence

With {"perplexity": ..., **metrics}, any preexisting metrics["perplexity"] overwrites the computed value. Flip the precedence (or assign in-place) and add a small guard.

-def get_metrics_with_perplexity(metrics):
-    """Add perplexity to the metrics."""
-    metrics = {"perplexity": float(torch.exp(torch.tensor(metrics["eval_loss"]))), **metrics}
-    return metrics
+def get_metrics_with_perplexity(metrics):
+    """Add perplexity to the metrics."""
+    loss = metrics.get("eval_loss", None)
+    if loss is not None:
+        # math.exp avoids unnecessary torch tensor creation on CPU
+        import math
+        metrics["perplexity"] = float(math.exp(float(loss)))
+    return metrics
tests/_test_utils/examples/run_command.py (1)

35-44: Also set a sane default MASTER_ADDR when injecting MASTER_PORT

Some launchers expect MASTER_ADDR; defaulting it to localhost avoids env-dependent failures. Keeping allocation/race risks low is fine for tests.

 def run_example_command(cmd_parts: list[str], example_path: str, setup_free_port: bool = False):
     print(f"[{example_path}] Running command: {cmd_parts}")
     env = os.environ.copy()
 
     if setup_free_port:
         free_port = get_free_port()
         env["MASTER_PORT"] = str(free_port)
+        env.setdefault("MASTER_ADDR", "127.0.0.1")
 
     subprocess.run(cmd_parts, cwd=MODELOPT_ROOT / "examples" / example_path, env=env, check=True)
examples/llm_qat/main.py (1)

266-269: Guard eval-only printing against missing eval_loss

If trainer.evaluate() doesn’t return eval_loss, get_metrics_with_perplexity will no-op after the fix; keeping the current order is fine. Consider asserting presence during tests.

Would you like a small unit test to assert the presence/shape of evaluation metrics across backends?
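
A minimal sketch of such a test (hypothetical, not in the PR); it exercises only the metrics helper, assuming it is importable from the example's utils module:

import math
import pytest
from utils import get_metrics_with_perplexity  # assumed import path within examples/llm_qat

def test_perplexity_is_exp_of_eval_loss():
    metrics = get_metrics_with_perplexity({"eval_loss": 2.0})
    assert metrics["perplexity"] == pytest.approx(math.exp(2.0))
    assert metrics["eval_loss"] == 2.0  # original keys are preserved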

modelopt/torch/utils/network.py (2)

440-454: Use isinstance for unwrapping to handle subclasses

type(model) in SUPPORTED_WRAPPERS misses subclasses. Iterate with isinstance for robustness.

-    if force_unwrap:
-        try:
-            if type(model) in SUPPORTED_WRAPPERS:
-                return getattr(model, SUPPORTED_WRAPPERS[type(model)])
+    if force_unwrap:
+        try:
+            for wrapper_type, attr in SUPPORTED_WRAPPERS.items():
+                if isinstance(model, wrapper_type):
+                    return getattr(model, attr)
         except AttributeError:
             raise ValueError(
                 f"Model of type {type(model)} could not be forcefully unwrapped! Please manually"
                 " unwrap the model before passing it in."
             )
 
-    if type(model) in SUPPORTED_WRAPPERS:
+    for wrapper_type, attr in SUPPORTED_WRAPPERS.items():
+        if isinstance(model, wrapper_type):
             if raise_error:
                 raise ValueError(msg or f"Model {model} is wrapped by {type(model)}!")
             elif warn:
                 warnings.warn(msg or f"Model {model} is wrapped by {type(model)}; unwrapping...")
-        return getattr(model, SUPPORTED_WRAPPERS[type(model)])
+            return getattr(model, attr)
     return model

599-612: Also strip DataParallel’s 'module.' prefix in get_unwrapped_name

DP inserts the same prefix as DDP; include it in the check.

-    if isinstance(model, nn.parallel.DistributedDataParallel) or (
+    if isinstance(model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)) or (
         DeepSpeedEngine is not None and isinstance(model, DeepSpeedEngine)
     ):
         name = name.removeprefix("module.")
modelopt/torch/opt/plugins/peft.py (1)

84-95: Avoid KeyError and unify with utils: use set_quantizer_state_dict.

Direct indexing quantizer_state_dict[get_unwrapped_name(name)] can KeyError on naming mismatches. Prefer the helper which tolerates missing keys.

Apply this diff:

-    if os.path.isfile(_get_quantizer_state_save_path(model_id)):
-        from modelopt.torch.quantization.nn import TensorQuantizer
-
-        quantizer_state_dict = torch.load(
-            _get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False
-        )
-        for name, module in self.named_modules():
-            if isinstance(module, TensorQuantizer):
-                module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name)])
+    if os.path.isfile(_get_quantizer_state_save_path(model_id)):
+        from modelopt.torch.quantization.utils import set_quantizer_state_dict
+        quantizer_state_dict = torch.load(
+            _get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False
+        )
+        set_quantizer_state_dict(self, quantizer_state_dict)
examples/llm_qat/simple_qat_train.py (1)

124-125: Resolve quant config from mtq.config to avoid AttributeError.

Safer to fetch from mtq.config (choices originate there).

Apply this diff:

-    model = mtq.quantize(model, getattr(mtq, args.quant_cfg), calibrate)
+    cfg = getattr(mtq.config, args.quant_cfg)
+    model = mtq.quantize(model, cfg, calibrate)
tests/examples/llm_qat/test_llm_qat.py (1)

39-45: Parametrization includes deepspeed — gate if DS isn’t installed.

Consider conditionally skipping the deepspeed case when import deepspeed fails to avoid infra‑dependent failures in CI.

modelopt/torch/quantization/plugins/transformers_trainer.py (1)

134-143: Quant config resolution: prefer mtq.config with fallback.

Some configs live under mtq.config. Use that first, then fall back to root for re‑exports.

Apply this diff:

-        if quant_args is not None and getattr(quant_args, "quant_cfg", None):
-            quant_cfg = (
-                getattr(mtq, quant_args.quant_cfg)
-                if isinstance(quant_args.quant_cfg, str)
-                else quant_args.quant_cfg
-            )
+        if quant_args is not None and getattr(quant_args, "quant_cfg", None):
+            if isinstance(quant_args.quant_cfg, str):
+                quant_cfg = getattr(getattr(mtq, "config", mtq), quant_args.quant_cfg)
+            else:
+                quant_cfg = quant_args.quant_cfg
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76e8ce2 and 78f3aaf.

📒 Files selected for processing (14)
  • examples/llm_qat/accelerate_config/deepspeed.yaml (1 hunks)
  • examples/llm_qat/convert_sharded_ckpt.py (0 hunks)
  • examples/llm_qat/launch.sh (4 hunks)
  • examples/llm_qat/main.py (2 hunks)
  • examples/llm_qat/simple_qat_train.py (2 hunks)
  • examples/llm_qat/utils.py (1 hunks)
  • modelopt/torch/opt/conversion.py (1 hunks)
  • modelopt/torch/opt/plugins/peft.py (1 hunks)
  • modelopt/torch/quantization/nn/modules/quant_module.py (1 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
  • modelopt/torch/quantization/utils.py (2 hunks)
  • modelopt/torch/utils/network.py (3 hunks)
  • tests/_test_utils/examples/run_command.py (1 hunks)
  • tests/examples/llm_qat/test_llm_qat.py (3 hunks)
💤 Files with no reviewable changes (1)
  • examples/llm_qat/convert_sharded_ckpt.py
🧰 Additional context used
🧬 Code graph analysis (6)
modelopt/torch/opt/conversion.py (1)
modelopt/torch/utils/network.py (1)
  • unwrap_model (430-454)
examples/llm_qat/main.py (3)
examples/llm_qat/utils.py (1)
  • get_metrics_with_perplexity (172-175)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
  • evaluate (240-247)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/quantization/utils.py (3)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • TensorQuantizer (62-1182)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/opt/plugins/peft.py (1)
modelopt/torch/quantization/utils.py (1)
  • get_quantizer_state_dict (446-456)
tests/examples/llm_qat/test_llm_qat.py (1)
tests/examples/conftest.py (1)
  • tiny_llama_path (33-41)
modelopt/torch/quantization/plugins/transformers_trainer.py (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • TensorQuantizer (62-1182)
  • collect (1166-1176)
modelopt/torch/quantization/utils.py (4)
  • calibrate_with_adapters (275-286)
  • get_quantizer_state_dict (446-456)
  • is_quantized (239-243)
  • set_quantizer_state_dict (459-465)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/opt/conversion.py (3)
  • modelopt_state (444-486)
  • save (489-511)
  • restore_from_modelopt_state (514-576)
modelopt/torch/quantization/model_quant.py (2)
  • forward_loop (95-96)
  • quantize (132-227)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (10)
examples/llm_qat/main.py (2)

41-41: Import relocation LGTM

Switching get_metrics_with_perplexity to the local utils keeps training code decoupled from trainer internals.


263-263: Nice: rank‑0 “Training completed.” log

Good UX for multi-rank runs.

modelopt/torch/utils/network.py (2)

73-77: Wrapper table extension LGTM

Adding FSDP support in SUPPORTED_WRAPPERS is correct and unblocks unwrap/state‑dict flows across backends.


79-86: Conditional DeepSpeed support LGTM

Dynamic import and map extension are appropriate; no hard dependency.

modelopt/torch/opt/conversion.py (2)

383-384: Force-unwrapping semantics changed — verify wrapper coverage and failure mode.

Switching to unwrap_model(..., force_unwrap=True) skips warnings/errors and may raise if a supported wrapper’s expected attribute is absent. Please verify SUPPORTED_WRAPPERS reliably covers DDP, FSDP/FSDP2, and DeepSpeed in your test matrix to avoid surprise ValueErrors when attributes drift.
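
A quick, hypothetical sanity check along these lines (assumes SUPPORTED_WRAPPERS and unwrap_model are importable from modelopt.torch.utils.network and that torch.distributed is initialized for the DDP case):

import torch.nn as nn
from modelopt.torch.utils.network import SUPPORTED_WRAPPERS, unwrap_model  # assumed import path

def check_force_unwrap(model: nn.Module) -> None:
    # Registry coverage: every supported wrapper must map to an attribute name.
    for wrapper_type, attr_name in SUPPORTED_WRAPPERS.items():
        assert isinstance(attr_name, str), f"{wrapper_type} has no unwrap attribute"

    # Round trip through DDP; FSDP/DeepSpeed need backend-specific setup and are
    # better covered by the backend-parameterized example tests.
    wrapped = nn.parallel.DistributedDataParallel(model)
    unwrapped = unwrap_model(wrapped, force_unwrap=True)
    assert unwrapped is model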


469-471: LGTM: explicit force‑unwrap for state capture.

Mirrors the apply path and helps avoid wrapper leakage into saved state.

modelopt/torch/opt/plugins/peft.py (1)

60-63: Centralized quantizer snapshot is the right call.

Using get_quantizer_state_dict(self) removes duplication and avoids model.state_dict() pitfalls under FSDP.

examples/llm_qat/simple_qat_train.py (1)

90-93: Confirm config symbol source.

Defaulting to string is fine; ensure NVFP4_DEFAULT_CFG (and other choices) are exported at mtq.config or re-exported at mtq. If not, the resolution below will matter.
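
A tiny check to that effect could look like the sketch below (an assumption-laden snippet, not part of the PR; it only probes whether each CLI choice resolves at the package root or under mtq.config):

import modelopt.torch.quantization as mtq

def check_quant_cfg_exports(names=("NVFP4_DEFAULT_CFG",)):
    # Each CLI choice should resolve either at mtq.<NAME> or mtq.config.<NAME>.
    for name in names:
        cfg = getattr(mtq, name, None) or getattr(getattr(mtq, "config", mtq), name, None)
        assert cfg is not None, f"{name} is not exported by mtq or mtq.config"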

tests/examples/llm_qat/test_llm_qat.py (1)

36-37: Good call setting up a free port.

Reduces DDP/Accelerate flakiness in CI.

modelopt/torch/quantization/plugins/transformers_trainer.py (1)

261-293: Accelerate FSDP2 patch looks sound.

Hiding quantizer buffers during prepare avoids FSDP2's "all buffers must be sharded" assumption. Good containment and restoration.

coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
examples/llm_qat/launch.sh (3)

63-66: Guard divide-by-zero and clamp save_steps to ≥1.

torch.cuda.device_count() can be 0 (or Python/Torch may be unavailable), causing division by zero. Also when GPU_COUNT > 192, integer division yields 0.

-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+GPU_COUNT=$(python - <<'PY'
+try:
+    import torch
+    print(torch.cuda.device_count())
+except Exception:
+    print(0)
+PY
+)
+# Fallbacks and clamps
+[[ "$GPU_COUNT" =~ ^[0-9]+$ ]] || GPU_COUNT=1
+(( GPU_COUNT > 0 )) || GPU_COUNT=1
+# Calculate save_steps per GPU, ensure at least 1
+DEFAULT_SAVE_STEPS=$(( 192 / GPU_COUNT ))
+(( DEFAULT_SAVE_STEPS > 0 )) || DEFAULT_SAVE_STEPS=1

88-92: Quote-safe and unset-safe QUANT_CFG check.

Unquoted -z $QUANT_CFG can error or misbehave when unset or with spaces/hyphens.

-if [ -z $QUANT_CFG ]; then
-  QUANT_ARGS=""
-else
-  QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
-fi
+if [ -n "${QUANT_CFG:-}" ]; then
+  QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
+else
+  QUANT_ARGS=""
+fi

95-97: Unset/empty-safe MAX_STEPS handling.

Same quoting problem here.

-if [ ! -z $MAX_STEPS ]; then
+if [ -n "${MAX_STEPS:-}" ]; then
   OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
 fi
🧹 Nitpick comments (4)
examples/llm_qat/launch.sh (4)

21-26: Add a portable to_lower helper (avoid Bash 4+ dependency).

The script uses ${var,,} later, which breaks on macOS’ Bash 3.2. Add a tiny helper here and use it where you need lowercase comparisons.

Add this function just below parse_value:

to_lower() { printf '%s' "${1:-}" | tr '[:upper:]' '[:lower:]'; }

52-55: Fix invalid-argument error message.

Currently prints only the substring after “=”, which is confusing. Print the full flag.

-      >&2 printf "Error: Invalid argument ${1#*=}\n"
+      >&2 printf "Error: Invalid argument '%s'\n" "$1"

99-108: Lowercasing requires Bash 4+; use helper and warn when overriding backend due to --compress.

${var,,} breaks on macOS Bash. Also, silently forcing DDP when --compress True can surprise users—emit an info message.

Apply both changes (uses the to_lower helper suggested above):

-if [[ "${USE_FSDP2,,}" == "true" ]]; then
+if [[ "$(to_lower "${USE_FSDP2:-}")" == "true" ]]; then
   echo "Warning: --use_fsdp2 is deprecated. Use --backend=fsdp2 instead."
   BACKEND="fsdp2"
 fi

-# if compress is true, set backend to ddp
-if [[ "${COMPRESS,,}" == "true" ]]; then
-  BACKEND="ddp"
+# if compress is true, set backend to ddp
+if [[ "$(to_lower "${COMPRESS:-}")" == "true" ]]; then
+  if [[ "$(to_lower "${BACKEND:-}")" != "ddp" ]]; then
+    echo "Info: --compress True forces --backend=ddp (overriding '$BACKEND')."
+  fi
+  BACKEND="ddp"
 fi

139-142: Distillation backend check: avoid ${var,,}.

Use the helper for portability.

-  if [[ "${BACKEND,,}" == "fsdp1" || "${BACKEND,,}" == "fsdp2" ]]; then
+  if [[ "$(to_lower "${BACKEND:-}")" == "fsdp1" || "$(to_lower "${BACKEND:-}")" == "fsdp2" ]]; then
     FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
   fi
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 78f3aaf and a11ace9.

📒 Files selected for processing (1)
  • examples/llm_qat/launch.sh (4 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-docs
🔇 Additional comments (5)
examples/llm_qat/launch.sh (5)

21-25: parse_value helper looks good.

Compact and correct handling of both --flag=value and --flag value.


175-176: Confirm hard-coded gradient checkpointing.

It’s always enabled. Please confirm it’s intended across all backends/configs.


86-86: Default BACKEND is reasonable.

Defaulting to fsdp1 keeps current behavior while allowing overrides.


145-177: Use an args array for the accelerate command; verify deepspeed gradient_checkpointing

  • Replace the CMD string in examples/llm_qat/launch.sh (lines 145–177) with an args array and invoke "${args[@]}" to avoid shell-quoting and path/spacing pitfalls.
  • Confirm examples/llm_qat/accelerate_config/deepspeed.yaml does not set gradient_checkpointing that conflicts with --gradient_checkpointing True — the automated search returned no match; verify manually.

111-133: Use helper to lowercase BACKEND and add fsdp alias to error message.

File: examples/llm_qat/launch.sh Lines: 111-133 — replace bash-only ${BACKEND,,} with the repo to_lower helper for portability; verified accelerate configs present.

-case "${BACKEND,,}" in
+case "$(to_lower "${BACKEND:-}")" in
   "fsdp1"|"fsdp")
     CONFIG_FILE="fsdp1.yaml"
     FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
     ;;
   "fsdp2")
     echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
     CONFIG_FILE="fsdp2.yaml"
     FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
     ;;
   "ddp")
     CONFIG_FILE="ddp.yaml"
     FSDP_ARGS=""
     ;;
   "deepspeed")
     CONFIG_FILE="deepspeed.yaml"
     FSDP_ARGS=""
     ;;
   *)
-    echo "Error: Invalid backend '$BACKEND'. Supported backends: fsdp1, fsdp2, ddp, deepspeed"
+    echo "Error: Invalid backend '$BACKEND'. Supported backends: fsdp (alias fsdp1), fsdp2, ddp, deepspeed"
     exit 1
     ;;
 esac

Ensure to_lower is defined in the repo (or keep ${BACKEND,,} only if the script intentionally requires bash ≥4).

coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)

103-115: Fix: handle QuantizeConfig vs dict in check_awq_smoothquant()

quant_cfg may be a QuantizeConfig instance (not dict); calling .get() will raise AttributeError. Support both types.

Apply this diff:

-def check_awq_smoothquant(quant_cfg):
+def check_awq_smoothquant(quant_cfg):
     # TODO: Remove this once deepspeed for AWQ and SmoothQuant is added
     """Get the quantization type from the configuration."""
     if quant_cfg is None:
         return False
-    algorithm = quant_cfg.get("algorithm", {})
+    # Accept dict-like or QuantizeConfig
+    if isinstance(quant_cfg, dict):
+        algorithm = quant_cfg.get("algorithm", {})
+    else:
+        # QuantizeConfig or object with attribute
+        algorithm = getattr(quant_cfg, "algorithm", {}) or {}
     is_awq_smoothquant = False
     # Check SmoothQuant and AWQ
     if algorithm and ("smoothquant" in algorithm or "awq" in algorithm):
         is_awq_smoothquant = True
 
     return is_awq_smoothquant
♻️ Duplicate comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)

197-201: Bug: dataset can be None → TypeError on len(dataset)

If both train_dataset and eval_dataset are None, this crashes. Provide a clear error and compute length after selection.

Apply this diff:

-        dataset = self.train_dataset if self.train_dataset is not None else self.eval_dataset
-        num_samples = min(self.quant_args.calib_size, len(dataset))  # type: ignore [union-attr]
+        dataset = self.eval_dataset if self.eval_dataset is not None else self.train_dataset
+        assert dataset is not None, "Calibration requires either eval or train dataset."
+        num_samples = min(self.quant_args.calib_size, len(dataset))  # type: ignore [arg-type]
         dataset = torch.utils.data.Subset(dataset, list(range(num_samples)))
         data_loader = self.get_eval_dataloader(dataset)
🧹 Nitpick comments (8)
examples/llm_qat/accelerate_config/fsdp1.yaml (1)

7-7: Activation checkpointing enabled: verify perf/compat across backends and toolchain versions.

Turning this on reduces memory but increases recompute. Please:

  • Confirm torch/accelerate versions in CI support fsdp_activation_checkpointing with fsdp_use_orig_params: true and FSDP v1.
  • Watch test/runtime timeouts; QAT runs may slow notably.
  • Sanity‑check numerics and resume-from-checkpoint with this setting on.

Optionally gate via a CLI flag in launch.sh so users can toggle per run.

modelopt/torch/quantization/plugins/transformers_trainer.py (7)

163-169: Idempotency: avoid double‑patching accelerate.prepare

If init is called multiple times or subclasses re‑invoke the patch, _original_prepare could be overwritten. Guard the patch.

Apply this diff:

-        self._patch_accelerate_for_fsdp2_fix()
+        # Patch once
+        if not hasattr(self.accelerator, "_original_prepare"):
+            self._patch_accelerate_for_fsdp2_fix()

170-189: Distributed save ordering: add post‑save barrier

You barrier before saving, but not after. Add a post‑save barrier to ensure all ranks see a fully written file before proceeding.

Apply this diff:

-        if self.args.should_save:
-            torch.save(modelopt_full_state, self._modelopt_state_path)
+        if self.args.should_save:
+            torch.save(modelopt_full_state, self._modelopt_state_path)
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()

190-194: Robust load: add map_location to avoid device mismatches

torch.load without map_location can fail on CPU‑only envs or pick wrong GPU. Load to CPU then let restore helpers move tensors.

Apply this diff:

-        modelopt_full_state = torch.load(self._modelopt_state_path, weights_only=False)
+        modelopt_full_state = torch.load(
+            self._modelopt_state_path, map_location="cpu", weights_only=False
+        )

202-208: Clarify intent: use self.model by design; fix misleading comment/param

Per team learning, forward pass should invoke self.model, not the unwrapped model parameter. Rename the unused parameter and update the comment to prevent future regressions.

Apply this diff:

-        def forward_loop(model):
+        def forward_loop(_unused_unwrapped_model):
             for batch in tqdm(data_loader, desc="Calibrating"):
                 batch = self._prepare_inputs(batch)
-                # Important: We should forward pass using the unwrapped model
-                # mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop
-                self.model(**batch)
+                # Important: Forward with self.model to retain wrapper hooks (DDP/FSDP2/DeepSpeed).
+                # Do not use the unwrapped model parameter passed by mtq.quantize().
+                self.model(**batch)

Note: Using the retrieved learning for this file; keeping self.model is intentional.


203-203: Reduce multi‑rank tqdm spam during calibration

Disable progress bars on non‑main processes.

Apply this diff:

-            for batch in tqdm(data_loader, desc="Calibrating"):
+            for batch in tqdm(
+                data_loader,
+                desc="Calibrating",
+                disable=not self.accelerator.is_local_main_process,
+            ):

319-321: Device placement: prefer accelerator.device over .cuda()

Calling .cuda() can select the wrong device under multi‑GPU/Accelerate. Use to(self.accelerator.device).

Apply this diff:

-        self.model.cuda()
+        self.model.to(self.accelerator.device)

352-354: Avoid stringly‑typed state_dict_type checks

Relying on "SHARDED_STATE_DICT" substring is brittle. Prefer comparing against the enum/type provided by the FSDP plugin if available (e.g., FSDPStateDictType.SHARDED_STATE_DICT).

If importing the enum is not feasible here, at least gate on an attribute, e.g. hasattr(self.accelerator.state.fsdp_plugin, "state_dict_type") and compare to known constants from that module.
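
For reference, an enum-based variant might look like the sketch below; it assumes the plugin stores a torch StateDictType rather than a raw string, which should be verified against the Accelerate version pinned in CI:

from torch.distributed.fsdp import StateDictType

def uses_sharded_state_dict(accelerator) -> bool:
    fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)
    if fsdp_plugin is None:
        return False
    # Compare against the enum instead of matching on the string "SHARDED_STATE_DICT".
    return getattr(fsdp_plugin, "state_dict_type", None) == StateDictType.SHARDED_STATE_DICT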

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a11ace9 and f529d43.

📒 Files selected for processing (3)
  • examples/llm_qat/accelerate_config/fsdp1.yaml (1 hunks)
  • examples/llm_qat/launch.sh (4 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/llm_qat/launch.sh
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)

261-293: Quantizer buffers: verify to_empty()/restore sequence is safe

tq.to_empty() frees buffer storage; after prepare you only restore _non_persistent_buffers_set, not buffer contents/devices. Confirm that later calibration/restore repopulates buffers correctly for all TensorQuantizer modules under FSDP2.

Would you run a quick check across backends (FSDP2, DDP) that prints dtype/device and shapes of a few representative quantizer buffers before and after prepare, and post‑calibration, to ensure they’re valid?
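
Something like the following throwaway helper could be called before and after accelerator.prepare() and again after calibration (a debugging sketch only, not part of the PR):

from modelopt.torch.quantization.nn import TensorQuantizer

def dump_quantizer_buffers(model, max_quantizers=3):
    shown = 0
    for name, module in model.named_modules():
        if not isinstance(module, TensorQuantizer):
            continue
        for buf_name, buf in module.named_buffers(recurse=False):
            print(f"{name}.{buf_name}: shape={tuple(buf.shape)} dtype={buf.dtype} device={buf.device}")
        shown += 1
        if shown >= max_quantizers:
            break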

NVIDIA deleted a comment from coderabbitai bot Sep 15, 2025

coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/llm_qat/accelerate_config/fsdp1.yaml (1)

12-17: Fix mis-keyed FSDP option: FULL_SHARD belongs to sharding strategy, not reshard_after_forward.

fsdp_reshard_after_forward expects a boolean. FULL_SHARD should be set on fsdp_sharding_strategy. This mis-key will be ignored or cause unexpected behavior.

Apply this diff:

   fsdp_forward_prefetch: false
-  fsdp_offload_params: false
-  fsdp_reshard_after_forward: FULL_SHARD
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_reshard_after_forward: true
   fsdp_state_dict_type: FULL_STATE_DICT
examples/llm_qat/launch.sh (1)

63-66: Bug: division by zero when no GPUs.

GPU_COUNT can be 0 on CPU‑only runners; DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) will fail.

-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+GPU_COUNT=$(python - <<'PY'
+import os
+try:
+    import torch
+    n = torch.cuda.device_count()
+except Exception:
+    n = 0
+print(max(1, int(n)))
+PY
+)
+# Calculate save_steps (fallback to 192 when CPU-only)
+DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
♻️ Duplicate comments (5)
examples/llm_qat/accelerate_config/fsdp1.yaml (1)

23-23: num_processes: gpu — OK per earlier context; verify CI/runtime Accelerate version.

Per your note, this passes through to torchrun. Please ensure the deployed Accelerate version supports this token across all backends to avoid surprises on older runners.

examples/llm_qat/accelerate_config/deepspeed.yaml (1)

17-17: num_processes: gpu — acknowledged as intentional.

Keeping as-is per earlier discussion; just ensure runners use the compatible Accelerate version.

modelopt/torch/quantization/utils.py (1)

28-29: Right call: use get_unwrapped_name(name, model) to stabilize keys across wrappers.

modelopt/torch/quantization/plugins/transformers_trainer.py (2)

202-208: Forward pass via self.model is intentional — approved.

Per your note, use self.model(**batch) instead of the unwrapped parameter to respect Trainer hooks.


197-201: Fix None dataset path in calibration.

If both train and eval datasets are None, len(dataset) raises. Select dataset first and assert it exists.

-        dataset = self.train_dataset if self.train_dataset is not None else self.eval_dataset
-        num_samples = min(self.quant_args.calib_size, len(dataset))  # type: ignore [union-attr]
+        dataset = self.eval_dataset if self.eval_dataset is not None else self.train_dataset
+        assert dataset is not None, "Calibration requires either eval or train dataset."
+        num_samples = min(self.quant_args.calib_size, len(dataset))
🧹 Nitpick comments (18)
modelopt/torch/quantization/calib/histogram.py (1)

160-163: Broaden warning to all distributed wrappers (not just DDP).

Message still ends with “DDP modules,” which re-introduces wrapper specificity. Recommend generic phrasing.

-                " method is to use the same calibration dataset across all distributed data"
-                " parallel groups so that `amax` is the same for all DDP modules."
+                " method is to use the same calibration dataset across all distributed data"
+                " parallel groups so that `amax` is the same across all ranks/modules."
modelopt/torch/quantization/nn/modules/quant_module.py (1)

161-165: Good fix: ensures cleanup even on exceptions.

The try/finally wrapper guarantees _enable_weight_quantization is reset. Consider a nested-safe counter (increment/decrement) if nested contexts are possible, otherwise this is fine.
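
For illustration, the counter idea could look roughly like this (a sketch only; _weight_quantization_depth is a hypothetical attribute name):

from contextlib import contextmanager

@contextmanager
def quantize_weight(self):
    """Context in which `self.weight` is quantized; safe under nesting via a depth counter."""
    depth = getattr(self, "_weight_quantization_depth", 0)
    self._weight_quantization_depth = depth + 1
    self._enable_weight_quantization = True
    try:
        yield
    finally:
        self._weight_quantization_depth -= 1
        if self._weight_quantization_depth == 0:
            self._enable_weight_quantization = False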

examples/llm_qat/utils.py (1)

172-175: Avoid tensor creation; guard missing/overflow for perplexity.

Use math.exp on the float, handle missing eval_loss, and prevent overflow to keep this helper robust and cheap.

-def get_metrics_with_perplexity(metrics):
-    """Add perplexity to the metrics."""
-    metrics = {"perplexity": float(torch.exp(torch.tensor(metrics["eval_loss"]))), **metrics}
-    return metrics
+from math import exp
+
+def get_metrics_with_perplexity(metrics):
+    """Add perplexity from eval_loss if present; fall back gracefully."""
+    loss = metrics.get("eval_loss")
+    if loss is None:
+        return metrics
+    try:
+        ppl = float(exp(float(loss)))
+    except OverflowError:
+        ppl = float("inf")
+    return {**metrics, "perplexity": ppl}
tests/_test_utils/examples/run_command.py (1)

35-44: Set MASTER_ADDR alongside MASTER_PORT for deterministic local runs.

Helps avoid cross-env surprises when a global address is pre-set or missing.

 def run_example_command(cmd_parts: list[str], example_path: str, setup_free_port: bool = False):
     print(f"[{example_path}] Running command: {cmd_parts}")
     env = os.environ.copy()

     if setup_free_port:
         free_port = get_free_port()
         env["MASTER_PORT"] = str(free_port)
+        env.setdefault("MASTER_ADDR", "127.0.0.1")

     subprocess.run(cmd_parts, cwd=MODELOPT_ROOT / "examples" / example_path, env=env, check=True)
modelopt/torch/utils/network.py (1)

439-447: Improve error context when force_unwrap fails.

Include the expected attribute name in the error for faster diagnosis.

-        except AttributeError:
+        except AttributeError:
             raise ValueError(
-                f"Model of type {type(model)} could not be forcefully unwrapped! Please manually"
-                " unwrap the model before passing it in."
+                f"Model of type {type(model)} could not be forcefully unwrapped "
+                f"(missing attr '{SUPPORTED_WRAPPERS.get(type(model), 'module')}'). "
+                "Please unwrap the model before passing it in."
             )
modelopt/torch/quantization/conversion.py (1)

115-123: Avoid recomputing quantizer_state(model) twice.

Store it once to reduce traversal work on large models.

-    quantizer_state_dict = metadata["quantizer_state"]
-    unmatched_keys = quantizer_state_dict.keys() - quantizer_state(model).keys()
-    extra_keys = quantizer_state(model).keys() - quantizer_state_dict.keys()
+    quantizer_state_dict = metadata["quantizer_state"]
+    current_keys = quantizer_state(model).keys()
+    unmatched_keys = quantizer_state_dict.keys() - current_keys
+    extra_keys = current_keys - quantizer_state_dict.keys()
examples/llm_qat/main.py (2)

189-191: Clean stray characters in comment.

There’s a trailing “åå”.

-    # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.åå
+    # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.

266-269: Guard perplexity computation when eval_loss is absent.

Prevents KeyError on custom metrics.

-        metrics = trainer.evaluate()
-        metrics = get_metrics_with_perplexity(metrics)
-        print_rank_0(f"Evaluation results: \n{metrics}")
+        metrics = trainer.evaluate()
+        if "eval_loss" in metrics:
+            metrics = get_metrics_with_perplexity(metrics)
+        print_rank_0(f"Evaluation results: \n{metrics}")
modelopt/torch/opt/plugins/peft.py (1)

85-95: Pre-validate saved vs. current quantizer keys for clearer errors.

Direct indexing can raise a bare KeyError on mismatches. Add a short key-set check (mirrors conversion.restore_quantizer_state) for better diagnostics.

     if os.path.isfile(_get_quantizer_state_save_path(model_id)):
         from modelopt.torch.quantization.nn import TensorQuantizer
 
         quantizer_state_dict = torch.load(
             _get_quantizer_state_save_path(model_id), map_location="cpu", weights_only=False
         )
+        # Validate keys before loading for clearer messaging
+        expected = {
+            get_unwrapped_name(n, self)
+            for n, m in self.named_modules()
+            if isinstance(m, TensorQuantizer)
+        }
+        got = set(quantizer_state_dict.keys())
+        missing = expected - got
+        extra = got - expected
+        if missing or extra:
+            raise ValueError(f"Quantizer state key mismatch. missing={missing}, extra={extra}")
         for name, module in self.named_modules():
             if isinstance(module, TensorQuantizer):
-                module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name, self)])
+                key = get_unwrapped_name(name, self)
+                module.load_state_dict(quantizer_state_dict[key])
modelopt/torch/opt/conversion.py (1)

590-592: Harden restore semantics: reject wrapped models consistently.

restore() now asserts unwrapped input; restore_from_modelopt_state() does not. To avoid accidental misuse via direct calls, mirror the same check there.

 def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any]) -> nn.Module:
     """Restore the model architecture from the modelopt state dictionary based on the user-provided model."""
     # initialize ModelLikeModule if needed.
-    model = model if isinstance(model, nn.Module) else ModelLikeModule(model)
+    model = model if isinstance(model, nn.Module) else ModelLikeModule(model)
+    # Keep behavior consistent with `restore()`: do not allow wrapped models here either.
+    from modelopt.torch.utils import unwrap_model
+    model = unwrap_model(model, raise_error=True)
tests/examples/llm_qat/test_llm_qat.py (2)

39-45: Param sweep over backends is great; add env‑conditional skips for optional backends.

CI often lacks DeepSpeed (and sometimes FS*2). Guard the tests to skip when the backend isn’t available instead of failing.

 @pytest.mark.parametrize("backend", [
     "fsdp1",
     "fsdp2",
     "deepspeed",
     "ddp",
 ])
 def test_llama_qat_int4w_int8a(tiny_llama_path, tmp_path, backend):
+    if backend == "deepspeed":
+        pytest.importorskip("deepspeed")
+    if backend == "fsdp2":
+        torch = pytest.importorskip("torch")
+        from packaging.version import Version
+        if Version(torch.__version__) < Version("2.3"):
+            pytest.skip("FSDP2 requires torch>=2.3")

Repeat the small guard in test_llama_qat_int4w_int8a_direct_qat.


76-86: Direct QAT test: consider marking slow to keep CI wall‑time sane.

Mark as slow if your CI budget is tight.

-@pytest.mark.parametrize("backend", [
+@pytest.mark.slow
+@pytest.mark.parametrize("backend", [
modelopt/torch/quantization/utils.py (1)

459-466: Consider non‑strict load for forward/backward compat.

Older/newer TensorQuantizer shapes/keys may drift; strict=False reduces fragility while _load_from_state_dict still handles specifics.

-            module.load_state_dict(quantizer_state_dict[key])
+            module.load_state_dict(quantizer_state_dict[key], strict=False)

If you expect strict matching across versions, ignore this.

examples/llm_qat/launch.sh (2)

88-93: Quote QUANT_CFG in test to avoid word‑splitting.

Minor robustness fix.

-if [ -z $QUANT_CFG ]; then
+if [ -z "${QUANT_CFG}" ]; then
   QUANT_ARGS=""
 else
   QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
 fi

53-55: Error message prints the value, not the invalid flag.

Use $1 instead of ${1#*=} for clarity.

-      >&2 printf "Error: Invalid argument ${1#*=}\n"
+      >&2 printf "Error: Invalid argument %s\n" "$1"
modelopt/torch/quantization/plugins/transformers_trainer.py (3)

170-186: Potentially ineffective filter for modelopt_state_dict.

state is a (mode_str, state_dict) tuple; checking "kd_loss" in state won’t inspect metadata. The filter likely keeps everything.

-        modelopt_state["modelopt_state_dict"] = [
-            state
-            for state in modelopt_state["modelopt_state_dict"]
-            if "kd_loss" not in state and "export_student" not in state
-        ]
+        filtered = []
+        for m_str, m_state in modelopt_state["modelopt_state_dict"]:
+            meta = m_state.get("metadata", {})
+            if m_str in {"distill", "kd_loss"} or meta.get("export_student", False):
+                continue
+            filtered.append((m_str, m_state))
+        modelopt_state["modelopt_state_dict"] = filtered

Adjust the mode names if your registry uses different identifiers.


240-247: FSDP2 eval‑only hack looks correct; tiny nit: avoid creating opt on empty params.

If the model has no parameters (rare), next(self.model.parameters()) raises. Optional safeguard:

-            dummy_optimizer = torch.optim.SGD([next(self.model.parameters())], lr=0.0)
+            first_param = next(self.model.parameters(), None)
+            dummy_optimizer = torch.optim.SGD([first_param] if first_param is not None else [torch.nn.Parameter(torch.zeros(1, device=self.accelerator.device))], lr=0.0)

261-293: Accelerate FSDP2 patch is clever; leave breadcrumbs for future updates.

You’re touching private attrs (_non_persistent_buffers_set); add a one‑liner warning so future Accelerate upgrades are audited.

         def _modelopt_prepare(self, *args, **kwargs):
+            # NOTE: Relies on private Accelerate/FSDP internals; re‑verify if Accelerate/FSDP2 is updated.
             if not self.is_fsdp2:
                 return self._original_prepare(*args, **kwargs)

Please confirm the targeted Accelerate version(s) in CI where this patch is validated.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76e8ce2 and 6941586.

📒 Files selected for processing (18)
  • examples/llm_qat/accelerate_config/deepspeed.yaml (1 hunks)
  • examples/llm_qat/accelerate_config/fsdp1.yaml (1 hunks)
  • examples/llm_qat/convert_sharded_ckpt.py (0 hunks)
  • examples/llm_qat/launch.sh (4 hunks)
  • examples/llm_qat/main.py (2 hunks)
  • examples/llm_qat/simple_qat_train.py (2 hunks)
  • examples/llm_qat/utils.py (1 hunks)
  • modelopt/torch/opt/conversion.py (2 hunks)
  • modelopt/torch/opt/dynamic.py (1 hunks)
  • modelopt/torch/opt/plugins/peft.py (2 hunks)
  • modelopt/torch/quantization/calib/histogram.py (1 hunks)
  • modelopt/torch/quantization/conversion.py (2 hunks)
  • modelopt/torch/quantization/nn/modules/quant_module.py (1 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
  • modelopt/torch/quantization/utils.py (2 hunks)
  • modelopt/torch/utils/network.py (3 hunks)
  • tests/_test_utils/examples/run_command.py (1 hunks)
  • tests/examples/llm_qat/test_llm_qat.py (3 hunks)
💤 Files with no reviewable changes (1)
  • examples/llm_qat/convert_sharded_ckpt.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T16:36:42.871Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: examples/llm_qat/accelerate_config/deepspeed.yaml:17-17
Timestamp: 2025-09-15T16:36:42.871Z
Learning: In Hugging Face Accelerate configuration YAML files, num_processes can accept the string "gpu" as a value, which gets passed through to torch run under the hood and functions correctly, despite standard documentation showing integer values.

Applied to files:

  • examples/llm_qat/accelerate_config/deepspeed.yaml
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (8)
modelopt/torch/quantization/conversion.py (3)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • set_from_modelopt_state (1122-1140)
  • get_modelopt_state (1105-1120)
  • get_modelopt_state (1246-1248)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/opt/plugins/peft.py (2)
modelopt/torch/quantization/utils.py (1)
  • get_quantizer_state_dict (446-456)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/opt/dynamic.py (1)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
examples/llm_qat/main.py (3)
examples/llm_qat/utils.py (1)
  • get_metrics_with_perplexity (172-175)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
  • evaluate (240-247)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/quantization/utils.py (2)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • TensorQuantizer (62-1182)
modelopt/torch/opt/conversion.py (1)
modelopt/torch/utils/network.py (1)
  • unwrap_model (430-454)
tests/examples/llm_qat/test_llm_qat.py (1)
tests/examples/conftest.py (1)
  • tiny_llama_path (33-41)
modelopt/torch/quantization/plugins/transformers_trainer.py (6)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • TensorQuantizer (62-1182)
  • collect (1166-1176)
modelopt/torch/quantization/utils.py (5)
  • calibrate_with_adapters (275-286)
  • disable_lora_quantizers_in_config (289-296)
  • get_quantizer_state_dict (446-456)
  • is_quantized (239-243)
  • set_quantizer_state_dict (459-466)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/opt/conversion.py (3)
  • modelopt_state (444-486)
  • save (489-507)
  • restore_from_modelopt_state (510-567)
modelopt/torch/quantization/model_quant.py (3)
  • forward_loop (95-96)
  • quantize (132-227)
  • print_quant_summary (463-470)
modelopt/torch/distill/plugins/huggingface.py (1)
  • save_model (48-92)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (13)
modelopt/torch/utils/network.py (1)

73-86: Wrapper matrix looks good; DeepSpeed inclusion is correct.

Nice to see deepspeed handled conditionally.

If you rely on is_parallel() for routing, consider adding FSDP/DeepSpeed there in a follow-up.

examples/llm_qat/simple_qat_train.py (1)

90-92: Harden runtime config lookup for --quant-cfg

Guard the getattr and raise a clear ValueError listing valid choices; the provided verification script produced no output so I could not confirm that every entry in mtq.config.choices is defined on modelopt.torch.quantization — re-run the check or verify manually.

Location: examples/llm_qat/simple_qat_train.py lines 90–92 (also 124–125)

-    model = mtq.quantize(model, getattr(mtq, args.quant_cfg), calibrate)
+    cfg = getattr(mtq, args.quant_cfg, None)
+    if cfg is None:
+        raise ValueError(
+            f"Unknown quant config '{args.quant_cfg}'. Valid choices: {mtq.config.choices}"
+        )
+    model = mtq.quantize(model, cfg, calibrate)
examples/llm_qat/accelerate_config/fsdp1.yaml (1)

7-7: Enabling activation checkpointing: looks good.

Change is appropriate for memory savings under FSDP. No further action.

modelopt/torch/quantization/conversion.py (2)

124-133: Wrapper-aware naming in restore path — correct.

Passing model to get_unwrapped_name for both TensorQuantizer and QuantModule matches the updated API and fixes wrapper-prefix mismatches (DDP/FSDP/DeepSpeed).


169-172: Consistent keying for saved metadata — good.

Using get_unwrapped_name(n, model) ensures saved metadata lines up with restore under wrappers.

examples/llm_qat/accelerate_config/deepspeed.yaml (1)

1-24: DeepSpeed Accelerate config — sensible defaults.

Zero-3 with no offload and bf16 is consistent with the training flow.

Please confirm DS/Accelerate versions in CI match these fields (zero3_init_flag, zero3_save_16bit_model, offload_*_device: none) to avoid schema drift across versions.

modelopt/torch/opt/plugins/peft.py (1)

59-65: DRY improvement: centralize quantizer state collection — good.

Replacing ad-hoc collection with get_quantizer_state_dict(self) reduces duplication and handles FSDP safely.

modelopt/torch/opt/conversion.py (1)

383-384: Early force‑unwrap in apply_mode looks correct.

Using force_unwrap=True here prevents nested wrapper state from leaking into ModelOpt state. No issues.

Please confirm DeepSpeedEngine is included in SUPPORTED_WRAPPERS with the correct attribute (usually "module") so force_unwrap doesn’t throw for DS.
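
A quick hedged check (assumes deepspeed is installed and that SUPPORTED_WRAPPERS is importable from modelopt.torch.utils.network):

from deepspeed.runtime.engine import DeepSpeedEngine
from modelopt.torch.utils.network import SUPPORTED_WRAPPERS  # assumed import path

attr_name = SUPPORTED_WRAPPERS.get(DeepSpeedEngine)
assert attr_name is not None, "DeepSpeedEngine is not registered in SUPPORTED_WRAPPERS"
print(f"DeepSpeedEngine unwraps via attribute {attr_name!r}")  # expected: 'module'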

tests/examples/llm_qat/test_llm_qat.py (1)

36-37: Nice: free port setup prevents DDP/FSDP port clashes in CI.

modelopt/torch/quantization/utils.py (1)

446-456: Quantizer state export avoids FSDP hangs — good.

Implementation is lean and avoids model.state_dict(); keying via get_unwrapped_name is correct.
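
Roughly, the snapshot half of this helper pair can be summarized as below; this is an approximation based on the symmetric restore diff earlier in this review, not the exact implementation:

from modelopt.torch.quantization.nn import TensorQuantizer
from modelopt.torch.utils.network import get_unwrapped_name  # assumed import path

def snapshot_quantizer_state(model):
    # Per-quantizer state_dicts keyed by wrapper-agnostic names; avoids a full
    # model.state_dict() call, which can gather weights or hang under FSDP.
    return {
        get_unwrapped_name(name, model): module.state_dict()
        for name, module in model.named_modules()
        if isinstance(module, TensorQuantizer)
    }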

examples/llm_qat/launch.sh (1)

99-136: Backend routing LGTM; compress→ddp fallback is a sensible default.

Please ensure accelerate_config/{fsdp1,fsdp2,ddp,deepspeed}.yaml ship in wheels/sdists so tests don’t fail when installed from package.

modelopt/torch/quantization/plugins/transformers_trainer.py (2)

172-175: Barrier placement is fine; just ensure every rank calls this method.

Since _save_modelopt_state_with_weights() is invoked on all ranks, the barrier won’t deadlock. Keep it that way.


249-259: Good: switch to FULL_STATE_DICT for final save when not in train.

realAsma changed the title [1/N] Trainer improvements: QATTrainer training workflow fixes and clean up; Added backend specific unitests; [1/N] QATTrainer training workflow fixes and clean up; Added backend specific unitests; Sep 15, 2025
NVIDIA deleted a comment from coderabbitai bot Sep 15, 2025

coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
modelopt/torch/utils/network.py (1)

599-613: DP prefix handling added (addresses prior feedback); minor robustness nit.

Nice: now strips module. for DataParallel too. Consider stripping repeated prefixes to handle nested wrapping chains.

Apply this diff:

-    if isinstance(model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)) or (
+    if isinstance(model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)) or (
         DeepSpeedEngine is not None and isinstance(model, DeepSpeedEngine)
     ):
-        name = name.removeprefix("module.")
+        while name.startswith("module."):
+            name = name[len("module."):]
🧹 Nitpick comments (2)
modelopt/torch/utils/network.py (2)

90-93: Align is_parallel() with SUPPORTED_WRAPPERS.

Currently it misses FSDP and DeepSpeed. Use the registry to avoid drift.

Apply this diff:

 def is_parallel(model: nn.Module) -> bool:
-    """Check if a PyTorch model is parallelized."""
-    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
+    """Check if a PyTorch model is wrapped by a supported parallel/distributed wrapper."""
+    return any(isinstance(model, t) for t in SUPPORTED_WRAPPERS)

440-454: Use isinstance instead of exact type match when unwrapping.

Exact type checks can miss subclasses/proxies (e.g., FSDP variants).

Apply this diff:

-        try:
-            if type(model) in SUPPORTED_WRAPPERS:
-                return getattr(model, SUPPORTED_WRAPPERS[type(model)])
+        try:
+            for wrapper_t, attr in SUPPORTED_WRAPPERS.items():
+                if isinstance(model, wrapper_t):
+                    return getattr(model, attr)
         except AttributeError:
             raise ValueError(
                 f"Model of type {type(model)} could not be forcefully unwrapped! Please manually"
                 " unwrap the model before passing it in."
             )

-    if type(model) in SUPPORTED_WRAPPERS:
+    for wrapper_t, attr in SUPPORTED_WRAPPERS.items():
+        if isinstance(model, wrapper_t):
             if raise_error:
                 raise ValueError(msg or f"Model {model} is wrapped by {type(model)}!")
             elif warn:
                 warnings.warn(msg or f"Model {model} is wrapped by {type(model)}; unwrapping...")
-            return getattr(model, SUPPORTED_WRAPPERS[type(model)])
+            return getattr(model, attr)
     return model
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6941586 and 8cfad7c.

📒 Files selected for processing (2)
  • modelopt/torch/opt/dynamic.py (1 hunks)
  • modelopt/torch/utils/network.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/opt/dynamic.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/torch/utils/network.py (2)

73-77: Wrapper registry centralization and DeepSpeed dynamic add: LGTM.

Good move to make SUPPORTED_WRAPPERS the single source of truth and register DeepSpeed at runtime.

Also applies to: 79-86


599-613: Verified: call sites updated and Python requirement OK.

AST scan found no single-arg/zero-arg calls to get_unwrapped_name — all occurrences pass model (modelopt/torch/quantization/utils.py:455, quantization/conversion.py:126/131/169, opt/dynamic.py:1276, opt/plugins/peft.py:93). setup.py declares python_requires=">=3.10,<3.13" (pyproject targets py310), so removeprefix/removesuffix and PEP 585 generics are supported.

coderabbitai bot commented Sep 15, 2025

Walkthrough

Adds a DeepSpeed Accelerate config, enables FSDP activation checkpointing, centralizes model-aware unwrapping and quantizer state snapshot/restore, implements a stateful FSDP2-aware quantization flow with explicit save/restore, refactors example launch/tests for backend selection and port handling, and removes the sharded-checkpoint conversion script and its automated calls.

Changes

  • Accelerate / Backend configs (examples/llm_qat/accelerate_config/deepspeed.yaml, examples/llm_qat/accelerate_config/fsdp1.yaml): Adds deepspeed.yaml (LOCAL_MACHINE, distributed_type=DEEPSPEED, ZeRO-3, mixed_precision=bf16) and flips fsdp_activation_checkpointing to true in fsdp1.yaml.
  • Launch orchestration (examples/llm_qat/launch.sh): Adds parse_value() and a BACKEND default; consolidates backend handling (fsdp1/fsdp, fsdp2, ddp, deepspeed), maps legacy flags, reorders command construction, enables set -x, and removes the post-run convert_sharded_ckpt invocation.
  • Examples: main/train/utils/README (examples/llm_qat/main.py, examples/llm_qat/simple_qat_train.py, examples/llm_qat/utils.py, examples/llm_qat/README.md): Moves get_metrics_with_perplexity into examples/llm_qat/utils.py and uses it in evaluation; main.py simplifies the eval flow and logs "Training completed."; simple_qat_train.py makes --quant-cfg a string resolved via getattr(mtq, ...); removes one README comment line.
  • Removed conversion script & invocation (examples/llm_qat/convert_sharded_ckpt.py, examples/llm_qat/llama_factory/launch_llamafactory.sh): Deletes convert_sharded_ckpt.py and removes its invocation from launch_llamafactory.sh.
  • Transformers trainer / quantization flow (modelopt/torch/quantization/plugins/transformers_trainer.py): Large refactor to a stateful, FSDP2-aware quantization flow: explicit quantizer state management (get/set), a single _modelopt_state_path save/restore, a subset-based calibration forward loop, an Accelerate patch for FSDP2, and new training/prediction/evaluate methods with updated signatures; removes several legacy helpers.
  • Quantizer state utilities & PEFT integration (modelopt/torch/quantization/utils.py, modelopt/torch/opt/plugins/peft.py): Adds get_quantizer_state_dict(model) and set_quantizer_state_dict(model, quantizer_state_dict); the PEFT plugin save/load now uses these utilities and model-context keys for per-quantizer state (see the sketch after this list).
  • Unwrap/name utilities & network support (modelopt/torch/utils/network.py, modelopt/torch/opt/dynamic.py, modelopt/torch/quantization/conversion.py): Extends get_unwrapped_name(name, model) to accept model context, centralizes SUPPORTED_WRAPPERS (adds FSDP and an optional DeepSpeedEngine), and updates callers to the model-aware name lookup; DynamicSpace.config() and the conversion routines are updated accordingly.
  • Conversion / restore behavior (modelopt/torch/opt/conversion.py): apply_mode/modelopt_state unwrap calls use force_unwrap=True; restore now explicitly unwraps early (unwrap_model(model, raise_error=True)); docstrings adjusted.
  • Quantization core tweaks (modelopt/torch/quantization/nn/modules/quant_module.py, modelopt/torch/quantization/calib/histogram.py): Ensures _enable_weight_quantization is reset via try/finally in quantize_weight; simplifies the distributed calibration warning text.
  • Tests / test utils (tests/_test_utils/examples/run_command.py, tests/examples/llm_qat/test_llm_qat.py): run_example_command gains setup_free_port to set MASTER_PORT in the environment; tests are parameterized over backends (fsdp1, fsdp2, deepspeed, ddp), a direct-QAT test is added, and _run_command forwards setup_free_port=True.
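A minimal usage sketch of the new quantizer-state utilities (the two signatures are taken from the summary above; the model variable, file path, and torch.save/torch.load plumbing are illustrative):

import torch

from modelopt.torch.quantization.utils import get_quantizer_state_dict, set_quantizer_state_dict

# Snapshot every TensorQuantizer's buffers, keyed by unwrapped module names so the
# same dict applies whether `model` is plain, DDP-wrapped, or FSDP-wrapped.
quantizer_state = get_quantizer_state_dict(model)  # `model`: an already-quantized nn.Module (assumed to exist)
torch.save(quantizer_state, "quantizer_state.pth")

# ... later, e.g. after re-wrapping or re-preparing the model ...
set_quantizer_state_dict(model, torch.load("quantizer_state.pth"))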

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Launch as launch.sh
  participant Accel as Accelerate
  participant Trainer as QAT/QAD Trainer
  participant Model as Model
  participant Disk as Storage

  User->>Launch: run example (--backend ...)
  Launch->>Launch: parse args, set BACKEND
  Launch->>Accel: select config (fsdp1/fsdp2/ddp/deepspeed)
  Accel-->>Launch: environment prepared
  Launch->>Trainer: start

  rect rgba(220,235,255,0.35)
    Note right of Trainer: FSDP2-specific adjustments
    Trainer->>Trainer: _patch_accelerate_for_fsdp2_fix()
  end

  alt modelopt state exists
    Trainer->>Disk: read modelopt + quantizer state
    Trainer->>Trainer: _restore_modelopt_state_with_weights()
    Trainer->>Model: restore weights and quantizer buffers
  else no prior state
    Trainer->>Model: run subset forward-loop for calibration
    Trainer->>Trainer: apply quantization (mtq.quantize) and optional compress
    Trainer->>Disk: _save_modelopt_state_with_weights()
  end

  Trainer->>User: metrics (perplexity via examples.utils.get_metrics_with_perplexity)
sequenceDiagram
  autonumber
  participant PEFT as PEFT Plugin
  participant Utils as quantization.utils
  participant Model as Model
  participant Disk as Storage

  PEFT->>Utils: get_quantizer_state_dict(Model)
  Utils->>Model: traverse modules, collect TensorQuantizer.state_dict (keys via get_unwrapped_name(name, model))
  Utils-->>PEFT: quantizer_state_dict
  PEFT->>Disk: write modelopt + quantizer_state_dict

  Disk->>PEFT: read modelopt + quantizer_state_dict
  PEFT->>Utils: set_quantizer_state_dict(Model, quantizer_state_dict)
  Utils->>Model: restore per-quantizer state by model-aware keys

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I twitch my whiskers at config files,
I hop through wrappers, tests and trials.
I stash quant seeds in careful heaps,
Restore them swift from model sleeps.
A little rabbit, backend-wise—now bounding over logs and tries.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 62.50%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title concisely summarizes the PR’s primary work—QATTrainer training workflow fixes and cleanup plus added backend-specific unit tests—and aligns with the PR objectives and changed files (training/workflow fixes, FSDP2 compatibility, and new tests). It is specific and relevant to the changeset.
✨ Finishing touches
  • 📝 Generate Docstrings


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
modelopt/torch/opt/dynamic.py (1)

1281-1335: Bug: select() mismatches keys under wrappers; normalize names before comparisons.

config() now returns unwrapped keys, but select() compares against raw names from named_hparams(). Under DDP/DeepSpeed, this rejects valid configs. Normalize both sides with get_unwrapped_name(..., self.model).

Apply this diff:

 def select(self, config: dict[str, Any], strict: bool = True) -> None:
@@
-        configurables = dict(self.named_hparams(configurable=True))
+        # Normalize keys to work under wrappers (DDP/DeepSpeed).
+        _unwrap = lambda n: get_unwrapped_name(n, self.model)
+        normalized_config = {_unwrap(k): v for k, v in config.items()}
+        configurables_raw = dict(self.named_hparams(configurable=True))
+        configurables = {_unwrap(n): hp for n, hp in configurables_raw.items()}
@@
-        check_non_configurable = any(
-            name in config and name not in configurables for name, hp in self.named_hparams()
-        )
+        all_hps_unwrapped = {_unwrap(n): hp for n, hp in self.named_hparams()}
+        check_non_configurable = any(
+            n in normalized_config and n not in configurables for n in all_hps_unwrapped
+        )
@@
-        unexpected_keys = dict.fromkeys(config.keys(), True)
+        unexpected_keys = dict.fromkeys(normalized_config.keys(), True)
@@
-        for name, hparam in configurables.items():
-            if name in config:
-                hparam.active = config[name]
+        for name, hparam in configurables.items():
+            if name in normalized_config:
+                hparam.active = normalized_config[name]
                 unexpected_keys[name] = False
             elif strict:
                 missing_keys.append(name)
@@
-            for name, hparam in self.named_hparams():
-                if name in configurables:
+            for name, hparam in all_hps_unwrapped.items():
+                if name in configurables:
                     continue
-                if name not in config:
+                if name not in normalized_config:
                     missing_keys.append(name)
                     continue
                 unexpected_keys[name] = False
-                if hparam.active != config[name]:
+                if hparam.active != normalized_config[name]:
                     inconsistent_keys.append(
-                        f"{name}: active={hparam.active}, config={config[name]}"
+                        f"{name}: active={hparam.active}, config={normalized_config[name]}"
                     )
modelopt/torch/quantization/conversion.py (1)

90-93: Critical: convert_to_quantized_model return value ignored.

This leaves model as ModelLikeModule and prevents quantizer insertion before restore.

Apply this diff:

-    convert_to_quantized_model(model, config)
+    model, _ = convert_to_quantized_model(model, config)
examples/llm_qat/launch.sh (1)

63-66: Fix divide-by-zero when no GPUs are visible.

GPU_COUNT can be 0 in CI/CPU envs, causing DEFAULT_SAVE_STEPS=$((192 / 0)).

Apply:

-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+GPU_COUNT=$(python -c "import torch; import os; print(torch.cuda.device_count())")
+# Calculate save_steps (fallback to 1 when no GPUs)
+DEFAULT_SAVE_STEPS=$((192 / (GPU_COUNT > 0 ? GPU_COUNT : 1)))
♻️ Duplicate comments (3)
modelopt/torch/opt/dynamic.py (1)

1275-1277: Fix: pass nn.Module context to get_unwrapped_name — correct.

Using self.model ensures wrapper prefixes are stripped (DDP/DeepSpeed). Matches prior guidance.

modelopt/torch/quantization/utils.py (1)

446-456: Wrapper-stable quantizer keys — LGTM.

Using get_unwrapped_name(name, model) fixes cross-wrapper key drift. Matches prior guidance.

modelopt/torch/quantization/plugins/transformers_trainer.py (1)

197-201: Guard against missing datasets before len(...).

Order can raise TypeError when both datasets are None. Recommend:

-        dataset = self.train_dataset if self.train_dataset is not None else self.eval_dataset
-        num_samples = min(self.quant_args.calib_size, len(dataset))  # type: ignore [union-attr]
+        dataset = self.eval_dataset if self.eval_dataset is not None else self.train_dataset
+        assert dataset is not None, "Calibration requires either eval or train dataset."
+        num_samples = min(self.quant_args.calib_size, len(dataset))
🧹 Nitpick comments (16)
modelopt/torch/quantization/nn/modules/quant_module.py (2)

157-165: Make quantize_weight re-entrant to avoid disabling on nested contexts

With the current bool flag, nested with self.quantize_weight() blocks will disable weight quantization when the inner context exits. Use a depth counter so the flag only flips to False at depth 0.

Apply:

 @contextlib.contextmanager
 def quantize_weight(self):
     """Context in which `self.weight` is quantized."""
-    self._enable_weight_quantization = True
-    try:
-        yield
-    finally:
-        self._enable_weight_quantization = False
+    depth = getattr(self, "_weight_quantization_depth", 0) + 1
+    self._weight_quantization_depth = depth
+    self._enable_weight_quantization = True
+    try:
+        yield
+    finally:
+        depth -= 1
+        self._weight_quantization_depth = depth
+        if depth == 0:
+            self._enable_weight_quantization = False

181-187: Register a depth counter attribute for re-entrancy

Initialize the counter alongside the existing flag.

Apply:

     def _setup(self):
         super()._setup()
         self._register_temp_attribute(
             "weight_quantizer", TensorQuantizer(self.default_quant_desc_weight)
         )
         self._register_temp_attribute("_enable_weight_quantization", False)
+        self._register_temp_attribute("_weight_quantization_depth", 0)
         self._register_dynamic_attribute("weight", self._get_quantized_weight)
examples/llm_qat/utils.py (1)

172-175: Guard against missing/invalid eval_loss and avoid unnecessary tensor ops

Handle absent eval_loss/loss keys and non-finite values; keep it torch-free for speed.

Apply:

-def get_metrics_with_perplexity(metrics):
-    """Add perplexity to the metrics."""
-    metrics = {"perplexity": float(torch.exp(torch.tensor(metrics["eval_loss"]))), **metrics}
-    return metrics
+def get_metrics_with_perplexity(metrics):
+    """Add perplexity to the metrics if loss is present."""
+    loss = metrics.get("eval_loss", metrics.get("loss"))
+    if loss is None:
+        return metrics
+    try:
+        ppl = float(torch.exp(torch.tensor(loss)))
+    except Exception:
+        return metrics
+    return {"perplexity": ppl, **metrics}
modelopt/torch/quantization/calib/histogram.py (1)

160-163: Neutral wording: not all setups are “DDP”

Since this path can run under FSDP/DeepSpeed too, avoid “DDP modules” phrasing.

Apply:

-                " method is to use the same calibration dataset across all distributed data"
-                " parallel groups so that `amax` is the same for all DDP modules."
+                " method is to use the same calibration dataset across all distributed data"
+                " parallel groups so that `amax` is consistent across all modules."
tests/_test_utils/examples/run_command.py (1)

35-44: Set MASTER_ADDR for robustness on multi-NIC hosts

Helps avoid rendezvous surprises on machines with multiple interfaces.

Apply:

 def run_example_command(cmd_parts: list[str], example_path: str, setup_free_port: bool = False):
     print(f"[{example_path}] Running command: {cmd_parts}")
     env = os.environ.copy()
 
     if setup_free_port:
         free_port = get_free_port()
         env["MASTER_PORT"] = str(free_port)
+        env.setdefault("MASTER_ADDR", "127.0.0.1")
 
     subprocess.run(cmd_parts, cwd=MODELOPT_ROOT / "examples" / example_path, env=env, check=True)
modelopt/torch/utils/network.py (3)

79-83: Narrow bare except to ImportError

Avoid swallowing unrelated runtime errors during import.

Apply:

-try:
+try:
     from deepspeed.runtime.engine import DeepSpeedEngine
-except:  # noqa: E722
+except ImportError:
     DeepSpeedEngine = None

438-454: Use isinstance against SUPPORTED_WRAPPERS for unwrapping (handles subclasses too)

Current type(model) in SUPPORTED_WRAPPERS misses subclasses and proxies.

Apply:

-    if force_unwrap:
-        try:
-            if type(model) in SUPPORTED_WRAPPERS:
-                return getattr(model, SUPPORTED_WRAPPERS[type(model)])
+    if force_unwrap:
+        try:
+            for wrapper_type, attr in SUPPORTED_WRAPPERS.items():
+                if isinstance(model, wrapper_type):
+                    return getattr(model, attr)
         except AttributeError:
             raise ValueError(
                 f"Model of type {type(model)} could not be forcefully unwrapped! Please manually"
                 " unwrap the model before passing it in."
             )
 
-    if type(model) in SUPPORTED_WRAPPERS:
+    for wrapper_type, attr in SUPPORTED_WRAPPERS.items():
+        if isinstance(model, wrapper_type):
             if raise_error:
                 raise ValueError(msg or f"Model {model} is wrapped by {type(model)}!")
             elif warn:
                 warnings.warn(msg or f"Model {model} is wrapped by {type(model)}; unwrapping...")
-        return getattr(model, SUPPORTED_WRAPPERS[type(model)])
+            return getattr(model, attr)
     return model

90-93: is_parallel should align with SUPPORTED_WRAPPERS (include FSDP/DeepSpeed when present)

Leverage the canonical wrapper list to avoid drift.

Apply:

 def is_parallel(model: nn.Module) -> bool:
     """Check if a PyTorch model is parallelized."""
-    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
+    return isinstance(model, tuple(SUPPORTED_WRAPPERS.keys()))
examples/llm_qat/simple_qat_train.py (2)

118-123: Calibrate without grads and restore mode.

Avoid unnecessary autograd and keep model mode intact.

Apply this diff:

-    def calibrate(m: nn.Module):
-        for batch in calib_dataloader:
-            m(batch["input_ids"].to(device))
+    def calibrate(m: nn.Module):
+        was_training = m.training
+        m.eval()
+        with torch.no_grad():
+            for batch in calib_dataloader:
+                m(batch["input_ids"].to(device))
+        if was_training:
+            m.train()

109-117: Nit: avoid double .cuda()/.to(device).

Define device before model creation and move once.

Apply this diff:

-    # Load model and initialize loss
-    model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda()
-    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
-
-    # Get dataloaders
-    train_dataloader, calib_dataloader = get_dataloader(args, tokenizer)
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # Load model and initialize loss
+    model = AutoModelForCausalLM.from_pretrained(args.model_path).to(device)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+
+    # Get dataloaders
+    train_dataloader, calib_dataloader = get_dataloader(args, tokenizer)
@@
-    model.to(device)
+    # already moved above

Also applies to: 130-132

examples/llm_qat/main.py (1)

190-191: Typo: stray characters in comment.

Remove the trailing “åå”.

Apply this diff:

-    # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.åå
+    # Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.
modelopt/torch/quantization/utils.py (1)

459-466: Make restore tolerant; early-exit on empty dict.

Loading with strict=True can fail on minor shape/version drift; also no-op fast path helps. Suggest:

-def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict):
+def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict, strict: bool = True):
     """Set the state dict of the quantizers in the model."""
     from .nn import TensorQuantizer
 
+    if not quantizer_state_dict:
+        return
     for name, module in model.named_modules():
         key = get_unwrapped_name(name, model)
         if isinstance(module, TensorQuantizer) and key in quantizer_state_dict:
-            module.load_state_dict(quantizer_state_dict[key])
+            module.load_state_dict(quantizer_state_dict[key], strict=strict)
examples/llm_qat/launch.sh (3)

52-55: Print the invalid option, not its value.

Current message shows ${1#*=} which is the value. Use $1.

-      >&2 printf "Error: Invalid argument ${1#*=}\n"
+      >&2 printf "Error: Invalid argument %s\n" "$1"

88-93: Quote variables in tests to avoid word-splitting.

Prevent surprises if QUANT_CFG contains spaces.

-if [ -z $QUANT_CFG ]; then
+if [ -z "${QUANT_CFG:-}" ]; then
   QUANT_ARGS=""
 else
-  QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
+  QUANT_ARGS="--quant_cfg \"$QUANT_CFG\" --calib_size \"$CALIB_SIZE\""
 fi

94-99: Quote MAX_STEPS check.

Avoid unary operator errors when unset.

-if [ ! -z $MAX_STEPS ]; then
+if [ -n "${MAX_STEPS:-}" ]; then
   OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
 fi
modelopt/torch/quantization/plugins/transformers_trainer.py (1)

202-208: Unused param: clarify intent of forward_loop’s model arg.

Rename to underscore to avoid confusion with self.model.

-        def forward_loop(model):
+        def forward_loop(_):
             for batch in tqdm(data_loader, desc="Calibrating"):
                 batch = self._prepare_inputs(batch)
                 # Important: We should forward pass using the unwrapped model
                 # mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop
                 self.model(**batch)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76e8ce2 and ecfc043.

📒 Files selected for processing (18)
  • examples/llm_qat/accelerate_config/deepspeed.yaml (1 hunks)
  • examples/llm_qat/accelerate_config/fsdp1.yaml (1 hunks)
  • examples/llm_qat/convert_sharded_ckpt.py (0 hunks)
  • examples/llm_qat/launch.sh (4 hunks)
  • examples/llm_qat/main.py (2 hunks)
  • examples/llm_qat/simple_qat_train.py (2 hunks)
  • examples/llm_qat/utils.py (1 hunks)
  • modelopt/torch/opt/conversion.py (2 hunks)
  • modelopt/torch/opt/dynamic.py (1 hunks)
  • modelopt/torch/opt/plugins/peft.py (2 hunks)
  • modelopt/torch/quantization/calib/histogram.py (1 hunks)
  • modelopt/torch/quantization/conversion.py (2 hunks)
  • modelopt/torch/quantization/nn/modules/quant_module.py (1 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
  • modelopt/torch/quantization/utils.py (2 hunks)
  • modelopt/torch/utils/network.py (3 hunks)
  • tests/_test_utils/examples/run_command.py (1 hunks)
  • tests/examples/llm_qat/test_llm_qat.py (3 hunks)
💤 Files with no reviewable changes (1)
  • examples/llm_qat/convert_sharded_ckpt.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T16:36:42.871Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: examples/llm_qat/accelerate_config/deepspeed.yaml:17-17
Timestamp: 2025-09-15T16:36:42.871Z
Learning: In Hugging Face Accelerate configuration YAML files, num_processes can accept the string "gpu" as a value, which gets passed through to torch run under the hood and functions correctly, despite standard documentation showing integer values.

Applied to files:

  • examples/llm_qat/accelerate_config/deepspeed.yaml
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (8)
modelopt/torch/opt/dynamic.py (1)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/conversion.py (4)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • set_from_modelopt_state (1122-1140)
  • get_modelopt_state (1105-1120)
  • get_modelopt_state (1246-1248)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/nn/modules/quant_module.py (1)
  • QuantModule (37-114)
modelopt/torch/quantization/utils.py (3)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • TensorQuantizer (62-1182)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
tests/examples/llm_qat/test_llm_qat.py (1)
tests/examples/conftest.py (1)
  • tiny_llama_path (33-41)
examples/llm_qat/main.py (3)
examples/llm_qat/utils.py (1)
  • get_metrics_with_perplexity (172-175)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)
  • evaluate (240-247)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/opt/conversion.py (1)
modelopt/torch/utils/network.py (1)
  • unwrap_model (430-454)
modelopt/torch/opt/plugins/peft.py (2)
modelopt/torch/quantization/utils.py (1)
  • get_quantizer_state_dict (446-456)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • TensorQuantizer (62-1182)
  • collect (1166-1176)
modelopt/torch/quantization/utils.py (4)
  • calibrate_with_adapters (275-286)
  • get_quantizer_state_dict (446-456)
  • is_quantized (239-243)
  • set_quantizer_state_dict (459-466)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/opt/conversion.py (3)
  • modelopt_state (444-486)
  • save (489-507)
  • restore_from_modelopt_state (510-567)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (19)
examples/llm_qat/accelerate_config/deepspeed.yaml (1)

1-23: DS config looks good; num_processes: gpu accepted per project practice

Acknowledging prior learning in this repo that "gpu" is passed through to torch run and works as intended.

If you want, I can add a quick check in CI to print the resolved launch args from Accelerate for visibility.

modelopt/torch/utils/network.py (1)

605-609: Name cleaning logic: OK

Stripping "module." for DP/DDP/DS and then delegating to FSDP’s utility is correct.

tests/_test_utils/examples/run_command.py (1)

35-44: Pass setup_free_port=True for distributed example runs

Add setup_free_port=True to run_example_command calls that launch multi-process/backends (accelerate launch, ./launch.sh, torchrun).

  • tests/examples/llm_distill/test_llm_distill.py:23 (accelerate launch --multi_gpu)
  • tests/examples/llm_qat/test_llm_qat.py: update _run_command (line 23) and its callers (lines 50, 61, 78, 89, 102, 103)
  • tests/examples/speculative_decoding/test_eagle.py:22 (./launch.sh)
  • tests/examples/speculative_decoding/test_medusa.py:33, 51, 73 (./launch.sh calls)
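A sketch of one such updated call site (the test name and command-line flags are placeholders; only the setup_free_port=True keyword, matching the run_example_command signature shown above, is the point):

from _test_utils.examples.run_command import run_example_command  # import path assumed from the tests/ layout

def test_llama_distill(tiny_llama_path, tmp_path):
    run_example_command(
        [
            "accelerate", "launch", "--multi_gpu", "main.py",
            "--model_name_or_path", tiny_llama_path,
            "--output_dir", str(tmp_path),
        ],
        "llm_distill",
        setup_free_port=True,  # sets a free MASTER_PORT so concurrent distributed runs don't collide
    )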
modelopt/torch/opt/conversion.py (2)

383-383: Force-unwrapping in apply_mode — good call.

Ensures consistent behavior across wrappers before conversion.


590-592: Explicit unwrap on restore — correct.

Rejecting wrapped models up-front improves clarity and avoids subtle failures.
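A rough sketch of the two behaviors being approved here, not the library code itself (only the unwrap_model keywords shown in the diffs are assumed):

from modelopt.torch.utils.network import unwrap_model

# apply_mode / modelopt_state: silently strip a DDP/FSDP/DeepSpeed wrapper before conversion
base_model = unwrap_model(wrapped_model, force_unwrap=True)

# restore: refuse to operate on a wrapped model so the failure is loud and early
base_model = unwrap_model(maybe_wrapped_model, raise_error=True)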

examples/llm_qat/accelerate_config/fsdp1.yaml (1)

7-7: Enabling FSDP activation checkpointing — verify interaction with training args.

Ensure TrainingArguments.gradient_checkpointing=True (with use_reentrant) and this setting do not double-apply or conflict for the chosen transformers/accelerate versions in CI.

examples/llm_qat/simple_qat_train.py (2)

88-93: CLI default now a string name — OK if attributes exist on mtq.

choices=mtq.config.choices should match attribute names. Looks fine.


124-125: Using getattr(mtq, args.quant_cfg) — correct given string default.

No issues.

modelopt/torch/quantization/conversion.py (1)

124-133: State keying now model‑contextual — LGTM.

Passing model into get_unwrapped_name aligns save/restore across wrappers.

Also applies to: 166-172

examples/llm_qat/main.py (1)

259-269: Eval path simplification with perplexity — LGTM.

Straightforward and clearer. Uses rank‑0 logging properly.

tests/examples/llm_qat/test_llm_qat.py (1)

36-44: Backend‑parameterized tests — nice coverage; watch runtime.

Great to exercise fsdp1/fsdp2/deepspeed/ddp. If CI time grows, consider marking deepspeed as nightly/slow.

Also applies to: 45-58, 70-86
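One possible shape for that split, keeping the backend matrix while deselecting deepspeed in per-PR CI (test and helper names are illustrative, and a registered slow marker is assumed):

import pytest

BACKENDS = [
    "fsdp1",
    "fsdp2",
    "ddp",
    pytest.param("deepspeed", marks=pytest.mark.slow),  # run nightly; deselect in PR CI with `-m "not slow"`
]

@pytest.mark.parametrize("backend", BACKENDS)
def test_llama_qat_backend(tiny_llama_path, tmp_path, backend):
    # _run_command's exact signature is assumed here; see test_llm_qat.py
    _run_command(tiny_llama_path, tmp_path, backend=backend)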

modelopt/torch/opt/plugins/peft.py (2)

60-66: Centralized quantizer state save — good refactor.

Avoids state_dict() pitfalls under FSDP and keeps logic in one place.


91-94: Load with model‑context unwrapped names — correct.

Indexes match saved dict under wrappers.

examples/llm_qat/launch.sh (4)

21-25: Argument parser looks good.

Handles both --arg=value and --arg value forms cleanly.


99-105: Deprecation path — LGTM.

Clear message; keeps backward-compat for --use_fsdp2.


110-136: Backend switch is tidy.

Centralized selection and GC args per backend read well.


138-146: Scoped FSDP RAM-efficient toggle — LGTM.

Only applied for fsdp1/fsdp2 when distill, which matches the note.

modelopt/torch/quantization/plugins/transformers_trainer.py (2)

251-259: Final FULL_STATE_DICT toggle — LGTM.

Good safeguard before the last checkpoint write.


163-169: Verify no HF/Accelerate prepare/load runs before restore (FSDP2/DDP/DeepSpeed)

  • QATTrainer patches accelerate.prepare then immediately restores modelopt state — modelopt/torch/quantization/plugins/transformers_trainer.py:L163-L169.
  • ModelOptHFTrainer.__init__ calls enable_huggingface_checkpointing() then super().__init__ (HF Trainer) — modelopt/torch/opt/plugins/transformers.py:L156-L159; check transformers.Trainer.__init__ (for your transformers + accelerate versions) for any accelerator.prepare(), torch.load()/from_pretrained(), or checkpoint hooks that run during __init__ (these would run before the QAT restore).
  • If unsure, reproduce with a minimal run that logs accelerator.prepare / torch.load invocation order under your target backends.
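A throwaway tracing sketch for that reproduction (debugging aid only; it wraps the two public entry points with plain Python before the trainer is built):

import functools

import accelerate
import torch

def traced(fn, label):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        print(f"[trace] {label} called")
        return fn(*args, **kwargs)
    return wrapper

# Log the ordering of prepare/load relative to the ModelOpt restore, then run the example as usual.
accelerate.Accelerator.prepare = traced(accelerate.Accelerator.prepare, "Accelerator.prepare")
torch.load = traced(torch.load, "torch.load")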

@realAsma realAsma force-pushed the asma/fix_fsdp2_state_fix branch from ab384f0 to c2eb5e5 on September 15, 2025 18:39
@realAsma realAsma force-pushed the asma/fix_fsdp2_state_fix branch 3 times, most recently from 1b14ede to 7d612fc on September 15, 2025 20:07

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
examples/llm_qat/launch.sh (1)

63-66: Avoid divide-by-zero and ensure save_steps ≥ 1.

GPU_COUNT can be 0 (CPU runs) or large enough to make 192/GPU_COUNT evaluate to 0. Guard and clamp to at least 1.

Apply:

-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+# Calculate save_steps (guard 0 GPUs and clamp to ≥1)
+if [ "${GPU_COUNT:-0}" -gt 0 ]; then
+  DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+  if [ "$DEFAULT_SAVE_STEPS" -lt 1 ]; then DEFAULT_SAVE_STEPS=1; fi
+else
+  DEFAULT_SAVE_STEPS=192
+fi
examples/llm_qat/simple_qat_train.py (2)

109-117: Unconditional .cuda() will crash on CPU-only; move device selection earlier and use .to(device).

Load, then move model to the selected device.

-    # Load model and initialize loss
-    model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda()
-    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
-    # Get dataloaders
-    train_dataloader, calib_dataloader = get_dataloader(args, tokenizer)
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # Load model and tokenizer
+    model = AutoModelForCausalLM.from_pretrained(args.model_path).to(device)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+    # Get dataloaders
+    train_dataloader, calib_dataloader = get_dataloader(args, tokenizer)

118-125: Calibrate without gradients and with eval mode to save memory and time.

Wrap calibration in inference_mode and temporarily switch to eval().

-    def calibrate(m: nn.Module):
-        for batch in calib_dataloader:
-            m(batch["input_ids"].to(device))
+    def calibrate(m: nn.Module):
+        was_training = m.training
+        m.eval()
+        with torch.inference_mode():
+            for batch in calib_dataloader:
+                m(batch["input_ids"].to(device))
+        if was_training:
+            m.train()
modelopt/torch/quantization/plugins/transformers_trainer.py (1)

353-359: export_student path doesn’t actually save the exported model under FSDP2 SHARDED_STATE_DICT.

You export into a local var and then call QATTrainer.save_model, which saves self.model (not the exported one). Swap in the exported model temporarily or save it directly.

-            if export_student:
-                model = self.accelerator.unwrap_model(self.model)
-                model = model.export()
-            return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs)
+            if export_student:
+                exported = self.accelerator.unwrap_model(self.model).export()
+                save_dir = output_dir or self.args.output_dir
+                exported.save_pretrained(save_dir)
+                return
+            return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs)
♻️ Duplicate comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)

203-210: Use of self.model in forward_loop is correct here.

Per maintainer guidance for this flow, forwarding via self.model (not the unwrapped param) is intended.

🧹 Nitpick comments (8)
examples/llm_qat/launch.sh (3)

88-92: Quote QUANT_CFG in tests to avoid word-splitting/unset issues.

Use parameter expansion with default to prevent “[ -z ]” pitfalls when unset.

-if [ -z $QUANT_CFG ]; then
+if [ -z "${QUANT_CFG:-}" ]; then
   QUANT_ARGS=""
 else
   QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
 fi

95-97: Quote MAX_STEPS check.

Minor shell robustness for unset values.

-if [ ! -z $MAX_STEPS ]; then
+if [ -n "${MAX_STEPS:-}" ]; then
   OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
 fi

52-55: Fix invalid-arg error message.

Currently prints only the value part; print the flag itself.

-      >&2 printf "Error: Invalid argument ${1#*=}\n"
+      >&2 printf "Error: Invalid argument %s\n" "$1"
examples/llm_qat/simple_qat_train.py (1)

46-51: Optional: tune DataLoader for GPU throughput.

Pin memory and allow workers to reduce host-device stalls.

-    train_dataloader = DataLoader(
-        train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
-    )
-    calib_dataloader = DataLoader(
-        calib_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
-    )
+    common_dl = dict(collate_fn=collate_fn, pin_memory=torch.cuda.is_available(), num_workers=2)
+    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **common_dl)
+    calib_dataloader = DataLoader(calib_dataset, batch_size=args.batch_size, shuffle=False, **common_dl)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)

241-249: Eval-only FSDP2 workaround: add safe param access.

next(self.model.parameters()) can raise StopIteration for paramless wrappers; guard defensively.

-            dummy_optimizer = torch.optim.SGD([next(self.model.parameters())], lr=0.0)
+            first_param = next((p for p in self.model.parameters()), None)
+            if first_param is None:
+                return super().evaluate(*args, **kwargs)
+            dummy_optimizer = torch.optim.SGD([first_param], lr=0.0)

250-261: Robust check for FULL_STATE_DICT.

State-dict type may be an enum/obj; compare via string to avoid false negatives.

-        if (
-            (not self.is_in_train)
-            and self.is_fsdp_enabled
-            and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
-        ):
+        if (
+            (not self.is_in_train)
+            and self.is_fsdp_enabled
+            and "FULL_STATE_DICT" not in str(self.accelerator.state.fsdp_plugin.state_dict_type)
+        ):

262-294: Optional: guard against double-patching accelerate.prepare.

Avoid overwriting _original_prepare if already patched.

-        self.accelerator._original_prepare = self.accelerator.prepare
+        if getattr(self.accelerator, "_original_prepare", None) is None:
+            self.accelerator._original_prepare = self.accelerator.prepare
         self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)

321-323: Avoid unconditional .cuda() in QAD flow.

Either assert CUDA availability or move to device with fallback to improve error messaging.

-        self.model.cuda()
+        if not torch.cuda.is_available():
+            raise RuntimeError("QAD requires CUDA; no GPU detected.")
+        self.model.cuda()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ecfc043 and 7d612fc.

📒 Files selected for processing (19)
  • examples/llm_qat/README.md (0 hunks)
  • examples/llm_qat/accelerate_config/deepspeed.yaml (1 hunks)
  • examples/llm_qat/accelerate_config/fsdp1.yaml (1 hunks)
  • examples/llm_qat/convert_sharded_ckpt.py (0 hunks)
  • examples/llm_qat/launch.sh (4 hunks)
  • examples/llm_qat/main.py (2 hunks)
  • examples/llm_qat/simple_qat_train.py (2 hunks)
  • examples/llm_qat/utils.py (1 hunks)
  • modelopt/torch/opt/conversion.py (2 hunks)
  • modelopt/torch/opt/dynamic.py (1 hunks)
  • modelopt/torch/opt/plugins/peft.py (2 hunks)
  • modelopt/torch/quantization/calib/histogram.py (1 hunks)
  • modelopt/torch/quantization/conversion.py (2 hunks)
  • modelopt/torch/quantization/nn/modules/quant_module.py (1 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
  • modelopt/torch/quantization/utils.py (2 hunks)
  • modelopt/torch/utils/network.py (3 hunks)
  • tests/_test_utils/examples/run_command.py (1 hunks)
  • tests/examples/llm_qat/test_llm_qat.py (3 hunks)
💤 Files with no reviewable changes (2)
  • examples/llm_qat/README.md
  • examples/llm_qat/convert_sharded_ckpt.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/quantization/calib/histogram.py
🚧 Files skipped from review as they are similar to previous changes (13)
  • examples/llm_qat/accelerate_config/fsdp1.yaml
  • examples/llm_qat/utils.py
  • tests/_test_utils/examples/run_command.py
  • examples/llm_qat/accelerate_config/deepspeed.yaml
  • modelopt/torch/quantization/utils.py
  • modelopt/torch/opt/conversion.py
  • modelopt/torch/opt/dynamic.py
  • tests/examples/llm_qat/test_llm_qat.py
  • modelopt/torch/quantization/conversion.py
  • modelopt/torch/opt/plugins/peft.py
  • examples/llm_qat/main.py
  • modelopt/torch/quantization/nn/modules/quant_module.py
  • modelopt/torch/utils/network.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • TensorQuantizer (62-1182)
  • collect (1166-1176)
modelopt/torch/quantization/utils.py (5)
  • calibrate_with_adapters (275-286)
  • disable_lora_quantizers_in_config (289-296)
  • get_quantizer_state_dict (446-456)
  • is_quantized (239-243)
  • set_quantizer_state_dict (459-466)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/opt/conversion.py (3)
  • modelopt_state (444-486)
  • save (489-507)
  • restore_from_modelopt_state (510-567)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (8)
examples/llm_qat/launch.sh (3)

99-108: Backward-compat for --use_fsdp2 and compress→ddp logic looks good.

Clear precedence and explicit deprecation notice. LGTM.


110-136: Backend switch-case is clean and explicit.

Good normalization to fsdp1/fsdp, fsdp2, ddp, deepspeed, and targeted args. LGTM.


139-146: FSDP memory-efficient loading toggle scoped to fsdp backends only.

Good constraint; avoids impacting ddp/deepspeed. LGTM.

examples/llm_qat/simple_qat_train.py (1)

88-93: CLI default-to-string with getattr is correct.

Switch to string default plus getattr aligns with choices and keeps args portable. LGTM.

modelopt/torch/quantization/plugins/transformers_trainer.py (4)

163-169: Good: restore ModelOpt state early if present.

Early restore path is aligned with the PR goal (fix FSDP2 and QLoRA interactions). LGTM.


195-201: Calibration dataset selection fix looks correct.

Select dataset before len() avoids None handling pitfalls. LGTM.


211-218: Quantize → optional compress flow is sensible.

Order and gating are appropriate. LGTM.


229-234: Lazy quantization trigger on first step is fine.

Ensures quantize runs once. LGTM.


codecov bot commented Sep 15, 2025

Codecov Report

❌ Patch coverage is 72.72727% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.85%. Comparing base (8a5736a) to head (c4bba4f).
⚠️ Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/quantization/utils.py 64.28% 5 Missing ⚠️
modelopt/torch/utils/network.py 66.66% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #318      +/-   ##
==========================================
- Coverage   73.86%   73.85%   -0.02%     
==========================================
  Files         172      172              
  Lines       17409    17430      +21     
==========================================
+ Hits        12860    12873      +13     
- Misses       4549     4557       +8     

☔ View full report in Codecov by Sentry.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)

353-359: export_student is ignored under FSDP2; exported model never saved

Local model = ...export() is unused and QATTrainer.save_model doesn’t accept export_student, so the student export is effectively dropped. Temporarily swap self.model to the exported student for saving and restore afterward.

-        if self.accelerator.is_fsdp2 and "SHARDED_STATE_DICT" in str(
-            self.accelerator.state.fsdp_plugin.state_dict_type
-        ):
-            if export_student:
-                model = self.accelerator.unwrap_model(self.model)
-                model = model.export()
-            return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs)
+        if self.accelerator.is_fsdp2 and "SHARDED_STATE_DICT" in str(
+            self.accelerator.state.fsdp_plugin.state_dict_type
+        ):
+            original_model = self.model
+            try:
+                if export_student:
+                    exported = self.accelerator.unwrap_model(self.model).export()
+                    self.model = exported
+                return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs)
+            finally:
+                self.model = original_model
♻️ Duplicate comments (2)
modelopt/torch/quantization/plugins/transformers_trainer.py (2)

197-201: LGTM: dataset selection order fixed

Selecting dataset before calling len(dataset) avoids None errors previously flagged.


278-289: to_empty() can wipe restored quantizer buffers; snapshot/restore around prepare

Calling to_empty() before FSDP2 prepare can drop quantizer buffer storage (problematic for eval-only flows after restoring state). Snapshot quantizer state before the loop and restore it after prepare.

             tq_og_non_prsist_buffers = {}
+            saved_tq_state = get_quantizer_state_dict(model)
             for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)):
                 tq.to_empty(device=self.device)
                 tq_og_non_prsist_buffers[tq] = tq._non_persistent_buffers_set.copy()
                 tq._non_persistent_buffers_set.update(tq._buffers.keys())
@@
             outputs = self._original_prepare(*args, **kwargs)
 
             for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)):
                 tq._non_persistent_buffers_set.clear()
                 tq._non_persistent_buffers_set = tq_og_non_prsist_buffers[tq]
+            set_quantizer_state_dict(model, saved_tq_state)

Optional nit: fix the variable name typo for readability.

-            tq_og_non_prsist_buffers = {}
+            tq_og_non_persistent_buffers = {}
@@
-                tq_og_non_prsist_buffers[tq] = tq._non_persistent_buffers_set.copy()
+                tq_og_non_persistent_buffers[tq] = tq._non_persistent_buffers_set.copy()
@@
-                tq._non_persistent_buffers_set = tq_og_non_prsist_buffers[tq]
+                tq._non_persistent_buffers_set = tq_og_non_persistent_buffers[tq]
🧹 Nitpick comments (3)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)

253-260: StateDictType comparison is brittle; string vs enum mismatch

state_dict_type != "FULL_STATE_DICT" will always be True if it’s an enum. Use string containment (as done elsewhere) or compare to the enum.

-            and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
+            and "FULL_STATE_DICT" not in str(self.accelerator.state.fsdp_plugin.state_dict_type)

292-293: Avoid double‑patching Accelerator.prepare

Guard _original_prepare assignment to prevent recursion if patched more than once.

-        self.accelerator._original_prepare = self.accelerator.prepare
-        self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
+        if not hasattr(self.accelerator, "_original_prepare"):
+            self.accelerator._original_prepare = self.accelerator.prepare
+        self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)

203-210: Forward loop param unused; mute tqdm on non‑master and clarify comment

Rename the unused param and avoid multi‑process progress spam. Keep using self.model(**batch) as intended.

-        def forward_loop(model):
-            for batch in tqdm(data_loader, desc="Calibrating"):
+        def forward_loop(_):
+            for batch in tqdm(
+                data_loader, desc="Calibrating", disable=not self.accelerator.is_local_main_process
+            ):
                 batch = self._prepare_inputs(batch)
-                # Important: We should forward pass using the unwrapped model
-                # mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop
+                # Intentionally call self.model; quantize() manages unwrapping internally.
                 self.model(**batch)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7d612fc and 77c48fe.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • TensorQuantizer (62-1182)
  • collect (1166-1176)
modelopt/torch/quantization/utils.py (5)
  • calibrate_with_adapters (275-286)
  • disable_lora_quantizers_in_config (289-296)
  • get_quantizer_state_dict (446-456)
  • is_quantized (239-243)
  • set_quantizer_state_dict (459-466)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/opt/conversion.py (3)
  • modelopt_state (444-486)
  • save (489-507)
  • restore_from_modelopt_state (510-567)
modelopt/torch/quantization/model_quant.py (2)
  • forward_loop (95-96)
  • quantize (132-227)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (3)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)

206-208: LGTM: forward pass uses self.model per intended design

This aligns with the agreed behavior for the calibration loop.


158-162: Confirm QuantizeConfig mapping semantics in AWQ/SmoothQuant check

quant_cfg.get("algorithm", {}) assumes dict‑like behavior. If quant_cfg can be a QuantizeConfig object, verify it supports .get(); otherwise normalize to a dict first.


165-169: LGTM: restore ModelOpt state early

Restoring before any weight load matches the FSDP2/QLoRA requirement and unblocks the workflow.

@ChenhanYu
Collaborator

ChenhanYu commented Sep 15, 2025

@kinjalpatel27
Contributor

Does the conversion from sharded_dict to full_dict for saving work for a slightly larger model (maybe 8B/14B), especially with QAD?

self.model(**batch)

# TODO: Remove calibrate_with_adpaters - this should not be needed
with calibrate_with_adapters(self.model, self.args):
Contributor Author


@sugunav14 why do we need to use the calibrate_with_adapters context manager? I doubt it is truly needed.

@mxinO
Copy link

mxinO commented Sep 16, 2025

LGTM, thanks for the fix!

@realAsma
Contributor Author

Does the conversion from sharded_dict to full_dict for saving work for a slightly larger model (maybe 8B/14B), especially with QAD?

@kinjalpatel27 I have tested an 8B model with QAT. Let me double-check with QAD (I assume it should work since QAD exports the student model before saving). Could you please take a look at the other things?

@kinjalpatel27
Contributor

@realAsma generally LGTM, two comments:

  1. The LlamaFactory launch script also calls convert_sharded_ckpt: https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/asma/fix_fsdp2_state_fix/examples/llm_qat/llama_factory/launch_llamafactory.sh#L259. Can you please remove that call if QAD + FSDP2 works fine?
  2. I recently noticed the description in the argparser of simple_qat_train.py is incorrect; can you please fix that too?

Will you be updating the README in a later PR?

Collaborator

@kevalmorabia97 kevalmorabia97 left a comment


Approving as codeowner to unblock

@kevalmorabia97 kevalmorabia97 removed the request for review from ChenhanYu September 16, 2025 06:37
@realAsma
Contributor Author

2. I recently noticed the description in the argparser of simple_qat_train.py is incorrect; can you please fix that too?

Fixed. There were some additional bugs in the script; I fixed those too.

Will you be updating the README in a later PR?

Yes, I have another PR planned that cleans up the QAD workflow. I will update the README in that PR.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/llm_qat/launch.sh (1)

62-66: Fix division-by-zero and clamp DEFAULT_SAVE_STEPS to ≥1.

On CPU-only runs (GPU_COUNT=0) the arithmetic fails. With large GPU counts, result may be 0 (invalid).

Apply:

 GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+# Calculate save_steps safely
+if ! [[ "$GPU_COUNT" =~ ^[0-9]+$ ]] || [ "$GPU_COUNT" -lt 1 ]; then
+  GPU_COUNT=1
+fi
+DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+if [ "$DEFAULT_SAVE_STEPS" -lt 1 ]; then
+  DEFAULT_SAVE_STEPS=1
+fi
🧹 Nitpick comments (10)
examples/llm_qat/launch.sh (6)

29-52: Tighten option patterns or guard for collisions.

Patterns like --model* can match unintended flags (e.g., --model_name...). Consider exact matches to avoid accidental captures.

Example:

-    --model*)                                   MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
+    --model|--model=*)                           MODEL=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;

Repeat for other options as needed.


88-98: Quote variables in tests to avoid word-splitting/globbing.

Unquoted vars in [ -z $QUANT_CFG ] and similar can misbehave.

Apply:

-if [ -z $QUANT_CFG ]; then
+if [ -z "${QUANT_CFG:-}" ]; then
   QUANT_ARGS=""
 else
   QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
 fi

 OPTIONAL_ARGS=""
-if [ ! -z $MAX_STEPS ]; then
+if [ -n "${MAX_STEPS:-}" ]; then
   OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
 fi

99-107: Override to ddp on --compress: add an explicit warning.

You silently override a user-selected backend. Surface a warning to avoid surprise.

Apply:

-if [[ "${COMPRESS,,}" == "true" ]]; then
-  BACKEND="ddp"
+if [[ "${COMPRESS,,}" == "true" ]]; then
+  echo "Info: --compress enabled; forcing --backend=ddp (FSDP not supported with compression)." >&2
+  BACKEND="ddp"
 fi

148-179: Quote CLI argument values in CMD for safety.

Paths or model ids with spaces/shell metacharacters can break the command.

Apply:

-CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
+CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
     main.py \
-    --model_name_or_path $MODEL \
-    --model_max_length $MAX_SEQ_LENGTH \
+    --model_name_or_path \"$MODEL\" \
+    --model_max_length \"$MAX_SEQ_LENGTH\" \
     --dataloader_drop_last True \
-    --do_train $DO_TRAIN \
+    --do_train \"$DO_TRAIN\" \
     --do_eval True \
-    --output_dir $OUTPUT_DIR \
-    --dataset $DATASET \
-    --train_size $TRAIN_SIZE \
-    --eval_size $EVAL_SIZE \
-    --num_train_epochs $NUM_EPOCHS \
-    --per_device_train_batch_size $TRAIN_BS \
-    --per_device_eval_batch_size $EVAL_BS \
-    --gradient_accumulation_steps $ACCUM_STEPS \
+    --output_dir \"$OUTPUT_DIR\" \
+    --dataset \"$DATASET\" \
+    --train_size \"$TRAIN_SIZE\" \
+    --eval_size \"$EVAL_SIZE\" \
+    --num_train_epochs \"$NUM_EPOCHS\" \
+    --per_device_train_batch_size \"$TRAIN_BS\" \
+    --per_device_eval_batch_size \"$EVAL_BS\" \
+    --gradient_accumulation_steps \"$ACCUM_STEPS\" \
     --eval_accumulation_steps 1 \
     --save_strategy steps \
-    --save_steps $SAVE_STEPS \
+    --save_steps \"$SAVE_STEPS\" \
     --eval_strategy steps \
-    --eval_steps $SAVE_STEPS \
+    --eval_steps \"$SAVE_STEPS\" \
     --load_best_model_at_end True \
     --save_total_limit 2 \
-    --learning_rate $LR \
+    --learning_rate \"$LR\" \
     --weight_decay 0.0 \
     --warmup_ratio 0.1 \
     --lr_scheduler_type linear \
     --logging_steps 1 \
     --report_to tensorboard \
-    --lora $LORA \
-    --compress $COMPRESS \
+    --lora \"$LORA\" \
+    --compress \"$COMPRESS\" \
       $GRADIENT_CHECKPOINTING_ARGS $QUANT_ARGS $OPTIONAL_ARGS $DISTILLATION_ARGS
 "

53-55: Minor: error message should print the offending token verbatim.

${1#*=} can mangle strings; use $1.

Apply:

-      >&2 printf "Error: Invalid argument ${1#*=}\n"
+      >&2 printf "Error: Invalid argument %s\n" "$1"

60-60: Trace mode: keep or gate by VERBOSE.

set -x is useful for CI but noisy for users. Consider gating on VERBOSE=1.

Example:

- set -x
+ [[ "${VERBOSE:-0}" == "1" ]] && set -x
modelopt/torch/quantization/plugins/transformers_trainer.py (4)

199-201: Add error handling for missing state file.

The method calls restore_modelopt_state_with_weights without checking if the file exists, which could raise a FileNotFoundError. This is inconsistent with the check performed in __init__ at lines 177-178.

Apply this diff to add error handling:

 def _restore_modelopt_state_with_weights(self):
+    if not os.path.exists(self._modelopt_state_path):
+        print_rank_0(f"ModelOpt state file not found: {self._modelopt_state_path}")
+        return
     restore_modelopt_state_with_weights(self.model, self._modelopt_state_path)
     print_rank_0("Restored modelopt state with weights.")

218-219: Consider removing the TODO for calibrate_with_adapters.

The TODO comment questions whether calibrate_with_adapters is needed, but since the PR objectives mention QLoRA support and testing, this context manager is likely required to properly disable LoRA adapters during calibration. Consider either removing the TODO or clarifying why it might not be needed in the future.


340-345: Add memory checks before moving model to GPU.

The self.model.cuda() call at line 339 could fail for large models that exceed GPU memory. Consider adding a try-catch or memory availability check, especially since the comment mentions that "memory efficient loading doesn't work" for QAD.

Apply this diff to add error handling:

-        self.model.cuda()
+        try:
+            self.model.cuda()
+        except torch.cuda.OutOfMemoryError as e:
+            raise RuntimeError(
+                "Failed to move model to GPU. QAD requires the entire model to fit in GPU memory. "
+                "Consider using QAT instead or reducing model/batch size."
+            ) from e

406-410: Add error handling for checkpoint loading failures.

The dist_cp.load_state_dict call could fail if the checkpoint is corrupted or incompatible. Consider adding error handling to provide a clearer error message.

Apply this diff to add error handling:

-    dist_cp.load_state_dict(
-        state_dict=state_dict,
-        storage_reader=dist_cp.FileSystemReader(sharded_model_path),
-        no_dist=True,
-    )
+    try:
+        dist_cp.load_state_dict(
+            state_dict=state_dict,
+            storage_reader=dist_cp.FileSystemReader(sharded_model_path),
+            no_dist=True,
+        )
+    except Exception as e:
+        print_rank_0(f"Failed to load sharded checkpoint from {sharded_model_path}: {e}")
+        raise
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 77c48fe and 53a4430.

📒 Files selected for processing (4)
  • examples/llm_qat/convert_sharded_ckpt.py (0 hunks)
  • examples/llm_qat/launch.sh (4 hunks)
  • examples/llm_qat/utils.py (1 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (7 hunks)
💤 Files with no reviewable changes (1)
  • examples/llm_qat/convert_sharded_ckpt.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/llm_qat/utils.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (5)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • TensorQuantizer (62-1182)
  • collect (1166-1176)
modelopt/torch/quantization/utils.py (5)
  • calibrate_with_adapters (275-286)
  • disable_lora_quantizers_in_config (289-296)
  • get_quantizer_state_dict (446-456)
  • is_quantized (239-243)
  • set_quantizer_state_dict (459-466)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/opt/conversion.py (4)
  • modelopt_state (444-486)
  • save (489-507)
  • state_dict (130-132)
  • load_state_dict (139-161)
modelopt/torch/opt/plugins/huggingface.py (1)
  • enable_huggingface_checkpointing (127-162)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (7)
examples/llm_qat/launch.sh (2)

110-137: Backend case: LGTM with minor polish.

Case handling is clear, includes fsdp↔fsdp1 alias and deprecation note for fsdp2.

Please confirm accelerate_config/{fsdp1.yaml,fsdp2.yaml,ddp.yaml,deepspeed.yaml} exist in this PR to prevent runtime failures.
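
A quick way to verify this locally (a throwaway check, assuming it is run from examples/llm_qat; the file names are taken from the list above):

from pathlib import Path

# Throwaway check: confirm every backend has a matching accelerate config file.
expected = ["fsdp1.yaml", "fsdp2.yaml", "ddp.yaml", "deepspeed.yaml"]
missing = [name for name in expected if not (Path("accelerate_config") / name).is_file()]
print("missing accelerate configs:", missing or "none")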


138-146: FSDP memory-efficient loading disable under distill: scoped correctly.

Scoping to fsdp1/fsdp2 only is appropriate.

modelopt/torch/quantization/plugins/transformers_trainer.py (5)

182-197: LGTM! Clean separation of state and weights for FSDP2 compatibility.

The save method now correctly stores both the filtered ModelOpt state dict and quantizer weights separately, which fixes the FSDP2 workflow issues mentioned in the PR objectives.


203-231: LGTM! Comprehensive quantization flow with dataset validation.

The quantization method properly validates dataset availability, creates a calibration subset, and saves the state after quantization. The integration with calibrate_with_adapters ensures LoRA adapters are correctly handled during calibration.


384-416: LGTM! Well-structured checkpoint conversion utility.

The new convert_sharded_model_to_hf_format function properly handles the conversion from FSDP sharded checkpoints to HuggingFace format, addressing the checkpoint compatibility issues mentioned in the PR objectives.


250-255: Verify FSDP2 eval-only workaround doesn't affect training mode.

Search found the dummy-optimizer hack only at modelopt/torch/quantization/plugins/transformers_trainer.py:250-255. Ensure that overwriting self.model via accelerator.prepare(dummy_optimizer) in the eval-only branch cannot leave the model in an FSDP-wrapped/prepared state that interferes with later training. Either use a locally prepared model for evaluation, or re-prepare the model with the real training optimizer before any subsequent training.


173-180: Ensure _patch_accelerate_for_fsdp2_fix() runs before any Accelerate prepare calls

The method monkey-patches accelerator.prepare to hide quantizer buffers, so it must execute before any call to .prepare() (including calls from the parent class __init__ or from plugins). Verify this by inspecting modelopt/torch/quantization/plugins/transformers_trainer.py (around lines 173–180) and confirming the patch call happens before any super().__init__() or plugin initialization that could invoke accelerator.prepare. If a prepare() can run earlier, move the patch to the very start of __init__ (or earlier) so it always precedes any .prepare() invocation. Automated sandbox checks failed; please run a repo search for ".prepare(" and confirm call order.
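
One way to run that check locally (a throwaway script, not part of the PR; it simply lists every prepare call site under modelopt/):

from pathlib import Path

# Throwaway search: print every ".prepare(" call site under modelopt/ with its line number.
hits = []
for path in Path("modelopt").rglob("*.py"):
    for lineno, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
        if ".prepare(" in line:
            hits.append(f"{path}:{lineno}: {line.strip()}")
print("\n".join(hits) or "no .prepare( calls found")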

@realAsma realAsma force-pushed the asma/fix_fsdp2_state_fix branch from 069b244 to 7551732 Compare September 16, 2025 14:09
@realAsma realAsma force-pushed the asma/fix_fsdp2_state_fix branch from 7551732 to b605d8c Compare September 16, 2025 18:28
print_rank_0("Setting state_dict_type to FULL_STATE_DICT for final checkpoint save.")
self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
outputs = super().save_model(*args, **kwargs)
torch.distributed.barrier()
Contributor Author

@realAsma realAsma Sep 16, 2025


@kinjalpatel27 Converting FSDP2 to use FULL_STATE_DICT works with both QAT and QAD on larger models after inserting this distributed barrier (the problem was that while rank 0 was still saving the model, the other ranks had already exited the program).


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/llm_qat/launch.sh (1)

63-66: Guard division by zero when no GPUs are present

GPU_COUNT can be 0, causing DEFAULT_SAVE_STEPS=$((192 / 0)) to fail.

Apply:

-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())" 2>/dev/null || echo 0)
+# Calculate save_steps (fallback to 192 on CPU/no CUDA)
+if (( GPU_COUNT > 0 )); then
+  DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+else
+  >&2 echo "Warning: No GPUs detected; using DEFAULT_SAVE_STEPS=192"
+  DEFAULT_SAVE_STEPS=192
+fi
♻️ Duplicate comments (1)
examples/llm_qat/launch.sh (1)

21-25: Fix arg parsing: avoid global shift + validate missing values

Current parse_value shifts the global positional params (double-shifts with the caller) and accepts a following flag as a value. Replace with a non-shifting, validating version.

Apply:

-# Helper function to parse a single argument value
-parse_value() {
-    if [[ "$1" != *=* ]]; then shift; fi
-    echo "${1#*=}"
-}
+# Helper: extract value from "--opt=value" or "--opt value"; fail if missing (no global shift)
+parse_value() {
+  local first="$1"
+  local next="${2-}"
+  if [[ "$first" == *=* ]]; then
+    echo "${first#*=}"
+    return 0
+  fi
+  if [[ -z "$next" || "$next" == --* || "$next" == -* ]]; then
+    >&2 echo "Error: Missing value for option '$first'"
+    exit 2
+  fi
+  echo "$next"
+}
🧹 Nitpick comments (18)
examples/llm_qat/launch.sh (6)

29-57: Correct invalid-arg error message to show the flag, not the value

Printing ${1#*=} hides the flag when using --opt=value.

Apply:

-      >&2 printf "Error: Invalid argument ${1#*=}\n"
+      >&2 printf "Error: Invalid argument %s\n" "$1"

88-97: Quote tests; prefer [[ -n ]] to avoid word-splitting and edge cases

Unquoted vars may misbehave when empty or starting with '-'.

Apply:

-if [ -z $QUANT_CFG ]; then
+if [[ -z "$QUANT_CFG" ]]; then
   QUANT_ARGS=""
 else
   QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
 fi

 OPTIONAL_ARGS=""
-if [ ! -z $MAX_STEPS ]; then
+if [[ -n "$MAX_STEPS" ]]; then
   OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
 fi

99-108: Warn when COMPRESS overrides user-selected BACKEND

Currently COMPRESS silently forces ddp. Emit an explicit notice.

Apply:

 # if compress is true, set backend to ddp
 if [[ "${COMPRESS,,}" == "true" ]]; then
-  BACKEND="ddp"
+  if [[ "${BACKEND,,}" != "ddp" ]]; then
+    echo "Info: --compress enabled; overriding backend '$BACKEND' -> 'ddp'"
+  fi
+  BACKEND="ddp"
 fi

148-179: Quote high-risk args in CMD to handle spaces safely

Model path, output dir, and dataset may contain spaces; quote them.

Apply:

-    --model_name_or_path $MODEL \
+    --model_name_or_path "$MODEL" \
@@
-    --output_dir $OUTPUT_DIR \
+    --output_dir "$OUTPUT_DIR" \
-    --dataset $DATASET \
+    --dataset "$DATASET" \

Optional follow-up: build CMD as an array to avoid word-splitting entirely.


181-183: Ensure timing prints even on failure

Use a trap so elapsed time is reported when the run exits early.

Apply:

-start_time=$(date +%s)
-sh -c "$CMD"
-echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
+start_time=$(date +%s)
+trap 'echo "Total time taken: $(( $(date +%s) - start_time )) seconds"' EXIT
+sh -c "$CMD"

60-60: Gate shell tracing behind DEBUG

set -x can leak sensitive info; make it opt-in.

Apply:

-set -x
+[[ "${DEBUG:-0}" == "1" ]] && set -x
modelopt/torch/quantization/conversion.py (2)

115-123: Avoid recomputing quantizer_state(model) and tighten mismatch reporting.

Compute it once to reduce traversal overhead and provide clearer errors.

-    quantizer_state_dict = metadata["quantizer_state"]
-    unmatched_keys = quantizer_state_dict.keys() - quantizer_state(model).keys()
-    extra_keys = quantizer_state(model).keys() - quantizer_state_dict.keys()
+    quantizer_state_dict = metadata["quantizer_state"]
+    current_state = quantizer_state(model)
+    unmatched_keys = quantizer_state_dict.keys() - current_state.keys()
+    extra_keys = current_state.keys() - quantizer_state_dict.keys()

Also applies to: 124-133


166-172: Guard against potential unwrapped-name collisions.

If two wrapped module paths normalize to the same unwrapped key, one state can overwrite the other silently. Consider asserting uniqueness (debug-only) or logging a warning.

 def quantizer_state(model: nn.Module) -> dict[str, Any]:
     """Returns the quantizer state dict describing the quantizer states in the model."""
-    return {
+    state = {
         get_unwrapped_name(n, model): m.get_modelopt_state()
         for n, m in model.named_modules()
         if isinstance(m, (TensorQuantizer, SequentialQuantizer))
     }
+    # Optional: debug safeguard
+    # assert len(state) == len({get_unwrapped_name(n, model) for n, m in model.named_modules()
+    #                           if isinstance(m, (TensorQuantizer, SequentialQuantizer))}), \
+    #     "Duplicate unwrapped quantizer names detected."
+    return state
modelopt/torch/quantization/plugins/transformers_trainer.py (10)

163-170: Early restore/save path looks good; ensure output_dir exists.

Create the directory before save/restore to avoid surprises on fresh runs.

-        self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth")
+        self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth")
+        os.makedirs(self.args.output_dir, exist_ok=True)

190-196: Guard when modelopt_state_weights is absent.

Older checkpoints may lack weights; avoid calling set_quantizer_state_dict with None.

-        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
+        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
         restore_from_modelopt_state(self.model, modelopt_state)
-        set_quantizer_state_dict(self.model, modelopt_weights)
+        if modelopt_weights is not None:
+            set_quantizer_state_dict(self.model, modelopt_weights)

197-216: Calibration loop: limit tqdm to main process.

Prevents N progress bars under DDP/FSDP.

-        def forward_loop(model):
-            for batch in tqdm(data_loader, desc="Calibrating"):
+        def forward_loop(model):
+            for batch in tqdm(data_loader, desc="Calibrating",
+                              disable=not self.accelerator.is_main_process):
                 batch = self._prepare_inputs(batch)
                 # Important: We should forward pass using the unwrapped model
                 # mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop
                 self.model(**batch)

212-213: Typo in TODO.

“adpaters” → “adapters”.


217-226: Emptying CUDA cache: gate on CUDA availability.

Avoids unnecessary call on CPU-only nodes.

-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()

242-249: Eval-only FSDP2 hack: OK; add small safeguard.

If model has no parameters (edge adapters), next(self.model.parameters()) raises StopIteration.

-            dummy_optimizer = torch.optim.SGD([next(self.model.parameters())], lr=0.0)
+            first_param = next(self.model.parameters(), None)
+            if first_param is None:
+                return super().evaluate(*args, **kwargs)
+            dummy_optimizer = torch.optim.SGD([first_param], lr=0.0)

259-274: Switching to FULL_STATE_DICT: consider restoring previous setting after save.

Prevents persistent side-effects if more saves follow.

-            self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
-            outputs = super().save_model(*args, **kwargs)
+            prev = self.accelerator.state.fsdp_plugin.state_dict_type
+            self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
+            try:
+                outputs = super().save_model(*args, **kwargs)
+            finally:
+                self.accelerator.state.fsdp_plugin.set_state_dict_type(prev)

275-307: Patch safety: avoid double-wrapping and restore non-persistent buffer set in-place.

  • Don’t overwrite _original_prepare if already patched.
  • Restore _non_persistent_buffers_set contents without reassigning the set (prevents reference invalidation).
-        def _modelopt_prepare(self, *args, **kwargs):
+        def _modelopt_prepare(self, *args, **kwargs):
             if not self.is_fsdp2:
                 return self._original_prepare(*args, **kwargs)
 
-            model = next((obj for obj in args if isinstance(obj, torch.nn.Module)), None)
+            model = next((obj for obj in args if isinstance(obj, torch.nn.Module)), None)
+            if model is None:
+                model = next((obj for obj in kwargs.values() if isinstance(obj, torch.nn.Module)), None)
             if model is None:
                 return self._original_prepare(*args, **kwargs)
 
             tq_og_non_prsist_buffers = {}
             for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)):
-                tq.to_empty(device=self.device)
                 tq_og_non_prsist_buffers[tq] = tq._non_persistent_buffers_set.copy()
                 tq._non_persistent_buffers_set.update(tq._buffers.keys())
 
             outputs = self._original_prepare(*args, **kwargs)
 
             for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)):
-                tq._non_persistent_buffers_set.clear()
-                tq._non_persistent_buffers_set = tq_og_non_prsist_buffers[tq]
+                tq._non_persistent_buffers_set.clear()
+                tq._non_persistent_buffers_set.update(tq_og_non_prsist_buffers[tq])
 
             return outputs
 
-        self.accelerator._original_prepare = self.accelerator.prepare
+        if getattr(self.accelerator, "_original_prepare", None) is None:
+            self.accelerator._original_prepare = self.accelerator.prepare
         self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)

291-296: Re-evaluate calling to_empty() on TensorQuantizer during prepare.

It can drop/restage buffers; you’re already hiding them via _non_persistent_buffers_set. If you keep it, gate it to meta-device tensors only.
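
If the call is kept, a minimal sketch of the gating idea (assuming standard PyTorch buffer semantics; the helper name is illustrative):

import torch

def materialize_quantizer_if_meta(tq: torch.nn.Module, device: torch.device) -> None:
    # Only materialize quantizer buffers that still live on the meta device;
    # buffers that already hold real storage are left untouched.
    if any(buf.is_meta for buf in tq.buffers(recurse=False)):
        tq.to_empty(device=device)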


366-375: State-dict type check by string match is brittle.

Prefer comparing the enum/constant directly if available.
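
A hedged sketch of the enum-based check (assuming the plugin stores a torch.distributed.fsdp.StateDictType; the string fallback covers configs that store plain strings):

from torch.distributed.fsdp import StateDictType

def is_full_state_dict(state_dict_type) -> bool:
    # Compare against the FSDP enum where possible instead of raw string matching.
    if isinstance(state_dict_type, StateDictType):
        return state_dict_type == StateDictType.FULL_STATE_DICT
    return str(state_dict_type).endswith("FULL_STATE_DICT")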

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7551732 and b605d8c.

📒 Files selected for processing (20)
  • examples/llm_qat/README.md (0 hunks)
  • examples/llm_qat/accelerate_config/deepspeed.yaml (1 hunks)
  • examples/llm_qat/accelerate_config/fsdp1.yaml (1 hunks)
  • examples/llm_qat/convert_sharded_ckpt.py (0 hunks)
  • examples/llm_qat/launch.sh (4 hunks)
  • examples/llm_qat/llama_factory/launch_llamafactory.sh (0 hunks)
  • examples/llm_qat/main.py (2 hunks)
  • examples/llm_qat/simple_qat_train.py (3 hunks)
  • examples/llm_qat/utils.py (1 hunks)
  • modelopt/torch/opt/conversion.py (2 hunks)
  • modelopt/torch/opt/dynamic.py (1 hunks)
  • modelopt/torch/opt/plugins/peft.py (2 hunks)
  • modelopt/torch/quantization/calib/histogram.py (1 hunks)
  • modelopt/torch/quantization/conversion.py (2 hunks)
  • modelopt/torch/quantization/nn/modules/quant_module.py (1 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
  • modelopt/torch/quantization/utils.py (2 hunks)
  • modelopt/torch/utils/network.py (3 hunks)
  • tests/_test_utils/examples/run_command.py (1 hunks)
  • tests/examples/llm_qat/test_llm_qat.py (3 hunks)
💤 Files with no reviewable changes (3)
  • examples/llm_qat/README.md
  • examples/llm_qat/convert_sharded_ckpt.py
  • examples/llm_qat/llama_factory/launch_llamafactory.sh
🚧 Files skipped from review as they are similar to previous changes (14)
  • examples/llm_qat/accelerate_config/fsdp1.yaml
  • tests/examples/llm_qat/test_llm_qat.py
  • modelopt/torch/opt/dynamic.py
  • modelopt/torch/quantization/nn/modules/quant_module.py
  • examples/llm_qat/utils.py
  • modelopt/torch/opt/conversion.py
  • modelopt/torch/quantization/calib/histogram.py
  • examples/llm_qat/main.py
  • examples/llm_qat/simple_qat_train.py
  • modelopt/torch/opt/plugins/peft.py
  • tests/_test_utils/examples/run_command.py
  • modelopt/torch/quantization/utils.py
  • examples/llm_qat/accelerate_config/deepspeed.yaml
  • modelopt/torch/utils/network.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.

Applied to files:

  • modelopt/torch/quantization/conversion.py
  • modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (2)
modelopt/torch/quantization/conversion.py (3)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • set_from_modelopt_state (1122-1140)
  • get_modelopt_state (1105-1120)
  • get_modelopt_state (1246-1248)
modelopt/torch/quantization/nn/modules/quant_module.py (1)
  • QuantModule (37-114)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • TensorQuantizer (62-1182)
  • collect (1166-1176)
modelopt/torch/quantization/utils.py (5)
  • calibrate_with_adapters (275-286)
  • disable_lora_quantizers_in_config (289-296)
  • get_quantizer_state_dict (446-456)
  • is_quantized (239-243)
  • set_quantizer_state_dict (459-466)
modelopt/torch/opt/conversion.py (3)
  • modelopt_state (444-486)
  • save (489-507)
  • restore_from_modelopt_state (510-567)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)

171-189: Barrier + state filtering LGTM; minor: save blob shape documented?

Consider documenting the saved schema (keys) for future compatibility/migrations.
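
For example, a short schema comment next to the save call would be enough (key names are taken from this trainer's save/restore pair; the exact value types are an assumption):

# Assumed layout of the blob written to modelopt_state_train.pth:
# {
#     "modelopt_state_dict": ...,     # filtered ModelOpt mode state (KD/export entries removed)
#     "modelopt_state_weights": ...,  # quantizer weights from get_quantizer_state_dict()
# }
MODELOPT_TRAIN_STATE_KEYS = ("modelopt_state_dict", "modelopt_state_weights")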


230-241: On-demand quantization trigger LGTM.

The gating is correct and minimizes redundant work.


335-335: Quantize-before-adding LoRA in QAD differs from QAT.

Confirm intent: QAT adds LoRA before quantization, QAD adds after. Verify LoRA quantizers are correctly disabled for QAD.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (2)
modelopt/torch/quantization/plugins/transformers_trainer.py (2)

197-204: LGTM: dataset selection bug fixed.

Now selects the available dataset before computing length; avoids NoneType len() error.


275-306: Avoid recursive accelerate.prepare patching (potential infinite recursion).

On repeated calls, _original_prepare gets overwritten with the already‑patched prepare, making the wrapper call itself. Guard the assignment and add a one‑time patch flag.

     def _patch_accelerate_for_fsdp2_fix(self):
@@
-        def _modelopt_prepare(self, *args, **kwargs):
+        def _modelopt_prepare(self, *args, **kwargs):
             if not self.is_fsdp2:
                 return self._original_prepare(*args, **kwargs)
@@
             return outputs
@@
-        self.accelerator._original_prepare = self.accelerator.prepare
-        self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
+        # Patch only once and keep a stable pointer to the original prepare
+        if not getattr(self.accelerator, "_modelopt_prepare_patched", False):
+            if getattr(self.accelerator, "_original_prepare", None) is None:
+                self.accelerator._original_prepare = self.accelerator.prepare
+            self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
+            self.accelerator._modelopt_prepare_patched = True
🧹 Nitpick comments (4)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)

291-296: Nit: fix variable name and keep set object stable.

Typo in variable name and you’re already restoring contents in-place (good). Rename for clarity.

-            tq_og_non_prsist_buffers = {}
+            tq_orig_non_persist_buffers = {}
@@
-                tq_og_non_prsist_buffers[tq] = tq._non_persistent_buffers_set.copy()
+                tq_orig_non_persist_buffers[tq] = tq._non_persistent_buffers_set.copy()
@@
-                tq._non_persistent_buffers_set.update(tq_og_non_prsist_buffers[tq])
+                tq._non_persistent_buffers_set.update(tq_orig_non_persist_buffers[tq])

Also applies to: 299-302


205-211: Calibrate under no_grad/inference_mode to cut memory/overhead.

Forward passes for calibration don’t need grads.

-        def forward_loop(model):
-            for batch in tqdm(data_loader, desc="Calibrating"):
-                batch = self._prepare_inputs(batch)
-                # Important: We should forward pass using the unwrapped model
-                # mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop
-                self.model(**batch)
+        def forward_loop(model):
+            with torch.inference_mode():
+                for batch in tqdm(data_loader, desc="Calibrating"):
+                    batch = self._prepare_inputs(batch)
+                    # Intentionally use self.model per HF/ModelOpt integration contract
+                    self.model(**batch)

212-216: Is calibrate_with_adapters needed? Gate or remove if redundant.

If LoRA disabling isn’t required, consider removing or gating via a flag to avoid surprising behavior.


242-249: Eval‑only FSDP2 prepare: guard models without parameters.

next(self.model.parameters()) will raise on param‑less models. Use a safe fallback.

-            dummy_optimizer = torch.optim.SGD([next(self.model.parameters())], lr=0.0)
+            first_param = next(iter(self.model.parameters()), None)
+            if first_param is None:
+                first_param = torch.nn.Parameter(torch.zeros(1, device=self.model.device))
+            dummy_optimizer = torch.optim.SGD([first_param], lr=0.0)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b605d8c and 032a4bf.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • TensorQuantizer (62-1182)
  • collect (1166-1176)
modelopt/torch/quantization/utils.py (5)
  • calibrate_with_adapters (275-286)
  • disable_lora_quantizers_in_config (289-296)
  • get_quantizer_state_dict (446-456)
  • is_quantized (239-243)
  • set_quantizer_state_dict (459-466)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)

217-226: LGTM: post‑calibration housekeeping.

Compression optionality, GC, CUDA cache clear, and persisting state look good.


230-235: LGTM: on‑demand quantization before first train step.

Simple and avoids repeated work once is_quantized() is true.


163-170: Startup restore/save flow: good ordering.

Restoring existing ModelOpt state before any weight loading and persisting post‑quantization aligns with the FSDP2/QLoRA fix.

Please confirm that all entry points which may construct a Trainer instance do so before any external weight loading occurs for FSDP2 paths.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)

302-302: Note: prior discussion on calling to_empty() stands.

You previously confirmed this is intended; leaving as-is.

🧹 Nitpick comments (9)
modelopt/torch/quantization/plugins/transformers_trainer.py (9)

49-55: Fix typos in docstring (clarity/polish).

“This classes” → “This class”; “taining” → “training”.

-    """Quantization arguments for quantization aware training.
+    """Quantization arguments for quantization-aware training.
@@
-    This classes is intended to be used with ModelOpt's QAT/QAD trainers for HuggingFace models.
+    This class is intended to be used with ModelOpt's QAT/QAD trainers for HuggingFace models.
@@
-    from the command line to the taining script.
+    from the command line to the training script.

92-101: Make AWQ/SmoothQuant detection robust for dict or QuantizeConfig.

check_awq_smoothquant assumes dict; support QuantizeConfig or objects with .algorithm.

-def check_awq_smoothquant(quant_cfg):
+def check_awq_smoothquant(quant_cfg):
@@
-    algorithm = quant_cfg.get("algorithm", {})
+    if hasattr(quant_cfg, "algorithm"):
+        algorithm = getattr(quant_cfg, "algorithm") or {}
+    elif isinstance(quant_cfg, dict):
+        algorithm = quant_cfg.get("algorithm", {}) or {}
+    else:
+        algorithm = {}

Also applies to: 103-115


163-170: Early FSDP2 patch + state restore/save ordering looks right.

Restoring ModelOpt states before weights and saving when already-quantized is correct for FSDP2/QLoRA flows. Consider ensuring output_dir exists.

-        self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth")
+        os.makedirs(self.args.output_dir, exist_ok=True)
+        self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth")

171-189: Save path: include barrier only when initialized; filtering KD/export state is fine.

Minor nit: optional guard for availability; otherwise LGTM.

-        if torch.distributed.is_initialized():
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
             torch.distributed.barrier()

222-227: Guard CUDA cache emptying.

Avoid calling empty_cache on CPU-only builds.

-        torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()

243-251: Handle models with zero trainable parameters in eval-only FSDP2.

next(self.model.parameters()) can raise StopIteration for fully-frozen/adapter-only models.

-            dummy_optimizer = torch.optim.SGD([next(self.model.parameters())], lr=0.0)
+            first_param = next(self.model.parameters(), None)
+            if first_param is None:
+                dummy = torch.nn.Parameter(torch.zeros(1, device=self.accelerator.device))
+                dummy_optimizer = torch.optim.SGD([dummy], lr=0.0)
+            else:
+                dummy_optimizer = torch.optim.SGD([first_param], lr=0.0)

261-283: Save path: compare fsdp state_dict_type via str() for robustness.

Some plugins use enums/objects; string compare avoids false negatives.

-            and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
+            and str(self.accelerator.state.fsdp_plugin.state_dict_type) != "FULL_STATE_DICT"

284-316: Make Accelerate prepare patch idempotent to avoid recursion on re‑patches.

Guard _original_prepare assignment so multiple trainers don’t capture the wrapped function and recurse.

-        self.accelerator._original_prepare = self.accelerator.prepare
-        self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
+        if getattr(self.accelerator, "_original_prepare", None) is None:
+            self.accelerator._original_prepare = self.accelerator.prepare
+        self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)

300-311: Nit: fix variable name and keep in‑place restore of non‑persistent buffer sets.

Spelling and clarity; logic is good.

-            tq_og_non_prsist_buffers = {}
+            tq_orig_non_persistent_buffers = {}
@@
-                tq_og_non_prsist_buffers[tq] = tq._non_persistent_buffers_set.copy()
+                tq_orig_non_persistent_buffers[tq] = tq._non_persistent_buffers_set.copy()
@@
-                tq._non_persistent_buffers_set.update(tq_og_non_prsist_buffers[tq])
+                tq._non_persistent_buffers_set.update(tq_orig_non_persistent_buffers[tq])
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 032a4bf and a9d4d93.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-16T20:14:34.725Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:191-191
Timestamp: 2025-09-16T20:14:34.725Z
Learning: The TensorRT-Model-Optimizer project only supports PyTorch >= 2.6, so using the `weights_only` parameter in torch.load calls is acceptable and doesn't require backward compatibility handling.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  • TensorQuantizer (62-1182)
  • collect (1166-1176)
modelopt/torch/quantization/utils.py (4)
  • calibrate_with_adapters (275-286)
  • get_quantizer_state_dict (446-456)
  • is_quantized (239-243)
  • set_quantizer_state_dict (459-466)
modelopt/torch/opt/conversion.py (5)
  • modelopt_state (444-486)
  • save (489-507)
  • restore_from_modelopt_state (510-567)
  • ModeloptStateManager (63-311)
  • is_converted (102-127)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
🔇 Additional comments (5)
modelopt/torch/quantization/plugins/transformers_trainer.py (5)

145-157: LoRA adapter handling before quantization looks good.


190-197: Restore logic is correct and backward-safe for missing weights.


198-205: Dataset selection ordering fix is correct.

Using a Subset of the first calib_size examples is fine; using the eval dataloader for a train dataset is acceptable for calibration, but verify that collator differences don't affect the calibration statistics.
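
A minimal sketch of the point to double-check (illustrative only; the helper mirrors the Subset usage in the trainer and makes the collator an explicit argument):

from torch.utils.data import DataLoader, Subset

def build_calib_loader(dataset, collate_fn, calib_size: int, batch_size: int) -> DataLoader:
    # Reuse the same collator the model sees in training so padding and label
    # handling do not skew the calibration statistics.
    calib_subset = Subset(dataset, range(min(calib_size, len(dataset))))
    return DataLoader(calib_subset, batch_size=batch_size, collate_fn=collate_fn)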


213-217: Revisit calibrate_with_adapters necessity (and ensure it’s a contextmanager).

If no functional benefit, consider removing to simplify; also confirm it’s decorated with @contextmanager.
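
For reference, the general shape of such a helper (a generic sketch, not the ModelOpt implementation; the two callbacks are hypothetical and supplied by the caller for their adapter type):

from contextlib import contextmanager

@contextmanager
def adapters_disabled(model, is_adapter_module, set_enabled):
    # Disable adapter modules for the duration of calibration, then re-enable them.
    adapters = [m for m in model.modules() if is_adapter_module(m)]
    for m in adapters:
        set_enabled(m, False)
    try:
        yield model
    finally:
        for m in adapters:
            set_enabled(m, True)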


229-230: Quant summary only on main process — good practice.

…; Cleaned QATTrainer

Signed-off-by: realAsma <[email protected]>

bug fix

Signed-off-by: realAsma <[email protected]>

Fixed full_state_dict hang by inserting barrier

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>

minor

Signed-off-by: realAsma <[email protected]>
@realAsma realAsma force-pushed the asma/fix_fsdp2_state_fix branch from a9d4d93 to c4bba4f Compare September 17, 2025 00:03
@realAsma realAsma enabled auto-merge (squash) September 17, 2025 00:05

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/llm_qat/launch.sh (2)

63-66: Guard division by zero when no GPUs are visible.

On CPU‑only nodes torch.cuda.device_count() returns 0 and the arithmetic expansion fails.

-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
+GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
+if [[ "$GPU_COUNT" -le 0 ]]; then
+  >&2 echo "Warning: No GPUs detected; defaulting save_steps to 192"
+  GPU_COUNT=1
+fi
+DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))

148-152: Move script-specific args after main.py.

Options like --fsdp_transformer_layer_cls_to_wrap are for the training script (main.py), not for accelerate; place $FSDP_ARGS after main.py to avoid "unrecognized argument" errors.

File: examples/llm_qat/launch.sh (lines 148-152)

-CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
-    main.py \
+CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE \
+    main.py $FSDP_ARGS \
♻️ Duplicate comments (2)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)

314-315: Make accelerate.prepare patch idempotent to avoid recursion on repeated patching.

If __init__ runs twice on the same Accelerator, _original_prepare will be overwritten with the wrapper, causing infinite recursion.

Apply:

-        self.accelerator._original_prepare = self.accelerator.prepare
-        self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
+        if getattr(self.accelerator, "_original_prepare", None) is None:
+            self.accelerator._original_prepare = self.accelerator.prepare
+        self.accelerator.prepare = types.MethodType(_modelopt_prepare, self.accelerator)
examples/llm_qat/launch.sh (1)

21-25: Harden parse_value to prevent option bleed‑through (missing value).

Current logic can treat the next flag as a value. Fail fast when the value is missing.

-# Helper function to parse a single argument value
-parse_value() {
-    if [[ "$1" != *=* ]]; then shift; fi
-    echo "${1#*=}"
-}
+# Helper: extract value from "--opt=value" or "--opt value"; error if missing
+parse_value() {
+  local first="$1"; shift || true
+  if [[ "$first" == *=* ]]; then
+    echo "${first#*=}"
+    return 0
+  fi
+  local next="${1-}"
+  if [[ -z "$next" || "$next" == --* || "$next" == -* ]]; then
+    >&2 echo "Error: Missing value for option '$first'"
+    exit 2
+  fi
+  echo "$next"
+}
🧹 Nitpick comments (10)
modelopt/torch/utils/network.py (1)

73-77: Good centralization of wrapper support; consider aligning is_parallel with SUPPORTED_WRAPPERS.

Nice consolidation, including optional DeepSpeed. For consistency, update is_parallel (Line 90-93) to derive from SUPPORTED_WRAPPERS so FSDP/DS are also detected.

Apply:

 def is_parallel(model: nn.Module) -> bool:
     """Check if a PyTorch model is parallelized."""
-    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
+    return type(model) in SUPPORTED_WRAPPERS

Also applies to: 79-86

modelopt/torch/opt/plugins/peft.py (1)

91-94: Avoid KeyError if a quantizer key is missing during load.

Guard the lookup so older/newer checkpoints don’t fail when module sets differ.

Apply:

-            if isinstance(module, TensorQuantizer):
-                module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name, self)])
+            if isinstance(module, TensorQuantizer):
+                key = get_unwrapped_name(name, self)
+                state = quantizer_state_dict.get(key)
+                if state is not None:
+                    module.load_state_dict(state)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)

206-212: Minor: prefer is_main_process for tqdm gating.

should_save often aligns with main rank but is semantically about checkpointing. Using accelerator.is_main_process is clearer.

Apply:

-            for batch in tqdm(data_loader, desc="Calibrating", disable=not self.args.should_save):
+            for batch in tqdm(
+                data_loader, desc="Calibrating", disable=not self.accelerator.is_main_process
+            ):

268-279: Restore FSDP state_dict_type even on exceptions during save.

Use try/finally to avoid leaking FULL_STATE_DICT if save_model raises.

Apply:

-            original_type = self.accelerator.state.fsdp_plugin.state_dict_type
-            self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
-            outputs = super().save_model(*args, **kwargs)
-            if torch.distributed.is_initialized():
-                torch.distributed.barrier()
-            if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)):
-                print_rank_0(
-                    "Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the "
-                    "model. See https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.opt.plugins.huggingface.html#modelopt.torch.opt.plugins.huggingface.enable_huggingface_checkpointing"
-                )
-            self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type)
+            original_type = self.accelerator.state.fsdp_plugin.state_dict_type
+            self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
+            try:
+                outputs = super().save_model(*args, **kwargs)
+                if torch.distributed.is_initialized():
+                    torch.distributed.barrier()
+                if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)):
+                    print_rank_0(
+                        "Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the "
+                        "model. See https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.opt.plugins.huggingface.html#modelopt.torch.opt.plugins.huggingface.enable_huggingface_checkpointing"
+                    )
+            finally:
+                self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type)

300-312: Ensure buffer‑hiding is reverted if prepare() throws.

Wrap the original prepare call in try/finally so _non_persistent_buffers_set is always restored.

Apply:

-            outputs = self._original_prepare(*args, **kwargs)
-
-            for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)):
-                tq._non_persistent_buffers_set.clear()
-                tq._non_persistent_buffers_set.update(tq_og_non_prsist_buffers[tq])
-
-            return outputs
+            try:
+                return self._original_prepare(*args, **kwargs)
+            finally:
+                for tq in (m for m in model.modules() if isinstance(m, TensorQuantizer)):
+                    tq._non_persistent_buffers_set.clear()
+                    tq._non_persistent_buffers_set.update(tq_og_non_prsist_buffers[tq])
examples/llm_qat/accelerate_config/fsdp1.yaml (1)

7-7: Enabling activation checkpointing: confirm intended trade‑offs and consistency with training args.

Turning on fsdp_activation_checkpointing: true increases recompute and reduces memory. Ensure this is intentional for fsdp1 runs and consistent with gradient checkpointing flags in launch.sh (which only enables --gradient_checkpointing for ddp/deepspeed). If fsdp1 also requires model‑level GC, consider wiring that similarly.

examples/llm_qat/launch.sh (4)

53-55: Fix invalid‑arg error message to print the flag itself.

${1#*=} trims up to =, producing odd output for space‑separated flags.

-      >&2 printf "Error: Invalid argument ${1#*=}\n"
+      >&2 printf "Error: Invalid argument %s\n" "$1"

88-97: Quote variable checks to avoid word‑splitting and unset pitfalls.

-if [ -z $QUANT_CFG ]; then
+if [ -z "${QUANT_CFG:-}" ]; then
   QUANT_ARGS=""
 else
   QUANT_ARGS="--quant_cfg $QUANT_CFG --calib_size $CALIB_SIZE"
 fi

-OPTIONAL_ARGS=""
-if [ ! -z $MAX_STEPS ]; then
+OPTIONAL_ARGS=""
+if [ -n "${MAX_STEPS:-}" ]; then
   OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
 fi

60-60: Gate tracing behind a DEBUG flag to reduce noise.

-set -x
+[[ "${DEBUG:-0}" == "1" ]] && set -x

148-179: Optional: avoid sh -c and word‑splitting by using an argv array.

Safer with paths containing spaces and avoids double parsing.

-start_time=$(date +%s)
-sh -c "$CMD"
+start_time=$(date +%s)
+eval "$CMD"
 echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"

Or better:

# Build an array instead of a string; then: "${CMD[@]}"
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a9d4d93 and c4bba4f.

📒 Files selected for processing (20)
  • examples/llm_qat/README.md (0 hunks)
  • examples/llm_qat/accelerate_config/deepspeed.yaml (1 hunks)
  • examples/llm_qat/accelerate_config/fsdp1.yaml (1 hunks)
  • examples/llm_qat/convert_sharded_ckpt.py (0 hunks)
  • examples/llm_qat/launch.sh (4 hunks)
  • examples/llm_qat/llama_factory/launch_llamafactory.sh (0 hunks)
  • examples/llm_qat/main.py (2 hunks)
  • examples/llm_qat/simple_qat_train.py (3 hunks)
  • examples/llm_qat/utils.py (1 hunks)
  • modelopt/torch/opt/conversion.py (2 hunks)
  • modelopt/torch/opt/dynamic.py (1 hunks)
  • modelopt/torch/opt/plugins/peft.py (2 hunks)
  • modelopt/torch/quantization/calib/histogram.py (1 hunks)
  • modelopt/torch/quantization/conversion.py (2 hunks)
  • modelopt/torch/quantization/nn/modules/quant_module.py (1 hunks)
  • modelopt/torch/quantization/plugins/transformers_trainer.py (5 hunks)
  • modelopt/torch/quantization/utils.py (2 hunks)
  • modelopt/torch/utils/network.py (3 hunks)
  • tests/_test_utils/examples/run_command.py (1 hunks)
  • tests/examples/llm_qat/test_llm_qat.py (3 hunks)
💤 Files with no reviewable changes (3)
  • examples/llm_qat/llama_factory/launch_llamafactory.sh
  • examples/llm_qat/convert_sharded_ckpt.py
  • examples/llm_qat/README.md
🚧 Files skipped from review as they are similar to previous changes (10)
  • modelopt/torch/quantization/calib/histogram.py
  • modelopt/torch/quantization/nn/modules/quant_module.py
  • modelopt/torch/opt/dynamic.py
  • tests/examples/llm_qat/test_llm_qat.py
  • tests/_test_utils/examples/run_command.py
  • modelopt/torch/quantization/conversion.py
  • examples/llm_qat/accelerate_config/deepspeed.yaml
  • examples/llm_qat/simple_qat_train.py
  • modelopt/torch/quantization/utils.py
  • examples/llm_qat/main.py
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2025-09-15T20:46:28.997Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:28.997Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.

Applied to files:

  • modelopt/torch/opt/conversion.py
  • modelopt/torch/opt/plugins/peft.py
  • modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-16T21:46:46.318Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-212
Timestamp: 2025-09-16T21:46:46.318Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the mtq.quantize function calls the forward_loop under a no_grad context, so wrapping the forward_loop in inference_mode or no_grad is not needed.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-15T16:40:12.757Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-208
Timestamp: 2025-09-15T16:40:12.757Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the forward_loop function in _quantize_model should use self.model(**batch) rather than the model parameter passed to forward_loop. The model parameter should not be used for the forward pass.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
📚 Learning: 2025-09-16T20:14:34.725Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:191-191
Timestamp: 2025-09-16T20:14:34.725Z
Learning: The TensorRT-Model-Optimizer project only supports PyTorch >= 2.6, so using the `weights_only` parameter in torch.load calls is acceptable and doesn't require backward compatibility handling.

Applied to files:

  • modelopt/torch/quantization/plugins/transformers_trainer.py
🧬 Code graph analysis (3)
modelopt/torch/opt/conversion.py (1)
modelopt/torch/utils/network.py (1)
  • unwrap_model (430-454)
modelopt/torch/opt/plugins/peft.py (2)
modelopt/torch/quantization/utils.py (1)
  • get_quantizer_state_dict (446-456)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/quantization/plugins/transformers_trainer.py (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • TensorQuantizer (61-1139)
  • disable (394-399)
  • collect (1123-1133)
modelopt/torch/quantization/utils.py (5)
  • calibrate_with_adapters (275-286)
  • disable_lora_quantizers_in_config (289-296)
  • get_quantizer_state_dict (446-456)
  • is_quantized (239-243)
  • set_quantizer_state_dict (459-466)
modelopt/torch/opt/conversion.py (5)
  • modelopt_state (444-486)
  • save (489-507)
  • restore_from_modelopt_state (510-567)
  • ModeloptStateManager (63-311)
  • is_converted (102-127)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (9)
examples/llm_qat/utils.py (1)

172-176: LGTM — simple, safe perplexity augmentation.

Handles scalar/tensor eval_loss correctly and returns a float. No issues.
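
A hedged sketch of the pattern being approved (the function name and metric key are illustrative, not the exact code in utils.py):

import math
import torch

def add_perplexity(metrics: dict) -> dict:
    # eval_loss may arrive as a Python float or a 0-dim tensor; normalize before exp().
    loss = metrics.get("eval_loss")
    if loss is None:
        return metrics
    if isinstance(loss, torch.Tensor):
        loss = loss.item()
    metrics["perplexity"] = float(math.exp(loss))
    return metrics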

modelopt/torch/utils/network.py (1)

599-612: get_unwrapped_name: DP/DS prefix cleanup + FSDP name normalization — looks correct.

The conditional removeprefix for DP/DS and subsequent _convert_to_wrapped_module_name cover common wrappers well.
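
A rough sketch of the DP/DS part of that cleanup (the "module." prefix is what DDP and DeepSpeed add when they wrap the user model; the FSDP-specific renaming done by _convert_to_wrapped_module_name is not reproduced here):

def strip_wrapper_prefix(name: str) -> str:
    # DDP/DeepSpeed expose the user model as `self.module`, so every submodule
    # name gains a leading "module." that must be removed before key matching.
    return name.removeprefix("module.")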

modelopt/torch/opt/conversion.py (1)

383-384: Wrapper handling policy is consistent across apply/save/restore.

  • apply_mode/modelopt_state: force unwrapping is appropriate.
  • restore: raising on wrapped models is correct to avoid ambiguous state.

Also applies to: 469-470, 591-592

modelopt/torch/opt/plugins/peft.py (1)

60-65: Centralized quantizer snapshotting — good move.

Using get_quantizer_state_dict() avoids FSDP hangs and keeps logic consistent.
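
Usage-wise the getter pairs with its setter; a minimal sketch of the round trip (the call signatures are assumed from how the trainer uses them):

from modelopt.torch.quantization.utils import (
    get_quantizer_state_dict,
    set_quantizer_state_dict,
)

def snapshot_and_restore_quantizers(model):
    # Snapshot only the quantizer state (no full-model state_dict collective),
    # then restore it onto the same or an identically converted model.
    quantizer_state = get_quantizer_state_dict(model)
    set_quantizer_state_dict(model, quantizer_state)
    return quantizer_state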

modelopt/torch/quantization/plugins/transformers_trainer.py (3)

165-170: Early restore/save of ModelOpt state — solid.

Restoring on startup and snapshotting when already quantized keeps flows resilient across backends.


171-189: Checkpoint payload layout is sensible; barrier guard present.

Filtered state plus quantizer weights is compact and sufficient. Pre‑save barrier is guarded. LGTM.


190-197: Restore path is correct and backward‑compatible.

weights_only use is fine given PyTorch ≥ 2.6, and the optional weights guard prevents crashes on older blobs.

examples/llm_qat/launch.sh (1)

110-131: Minor: ensure FSDP override is delivered to the script, not Accelerate.

After moving $FSDP_ARGS behind main.py, the distillation‑time override (--fsdp_cpu_ram_efficient_loading False) will reach main.py as intended. No change needed beyond the earlier repositioning; call this out to double‑check downstream parsing.

examples/llm_qat/accelerate_config/fsdp1.yaml (1)

13-13: Confirm FSDP version and correct fsdp_reshard_after_forward usage

File: examples/llm_qat/accelerate_config/fsdp1.yaml (line 13)

fsdp_reshard_after_forward accepts sharding-strategy strings for fsdp_version=1 but is a boolean for fsdp_version=2 — adjust based on your fsdp_version.

If fsdp_version == 2 apply:

-  fsdp_reshard_after_forward: FULL_SHARD
+  fsdp_reshard_after_forward: true
+  fsdp_sharding_strategy: FULL_SHARD

If fsdp_version == 1, no change required.

@realAsma realAsma merged commit 3524732 into main Sep 17, 2025
25 of 33 checks passed
@realAsma realAsma deleted the asma/fix_fsdp2_state_fix branch September 17, 2025 01:43
yeyu-nvidia pushed a commit that referenced this pull request Sep 18, 2025
…specific unitests; (#318)

Signed-off-by: realAsma <[email protected]>
Signed-off-by: Ye Yu <[email protected]>